aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorGravatar Jason Gross <jgross@mit.edu>2018-03-21 17:29:16 -0400
committerGravatar Jason Gross <jasongross9@gmail.com>2018-04-04 15:39:34 -0400
commit8e17c3d75ce9cb9d2c0c3921514e9318776a28de (patch)
tree78221eefcbc3bab08c4070c5d4e3aba559091ef0 /src
parentc900290d3297ade2cc2e73fe6b322abe52d1715a (diff)
Stick an uncurry pass in the pipeline
This allows us to (a) consolidate the constant and non-constant pipelines and (b) vastly simplify the call-with-id-continuation logic.
Diffstat (limited to 'src')
-rw-r--r--src/Experiments/SimplyTypedArithmetic.v729
1 files changed, 155 insertions, 574 deletions
diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v
index 202e21098..a8a98c37e 100644
--- a/src/Experiments/SimplyTypedArithmetic.v
+++ b/src/Experiments/SimplyTypedArithmetic.v
@@ -501,11 +501,10 @@ Section mod_ops.
Qed.
Derive carry_mulmod
- SuchThat (forall (fg : list Z * list Z)
- (f := fst fg) (g := snd fg)
+ SuchThat (forall (f g : list Z)
(Hf : length f = n)
(Hg : length g = n),
- (eval weight n (carry_mulmod fg)) mod (s - Associational.eval c)
+ (eval weight n (carry_mulmod f g)) mod (s - Associational.eval c)
= (eval weight n f * eval weight n g) mod (s - Associational.eval c))
As eval_carry_mulmod.
Proof.
@@ -516,7 +515,7 @@ Section mod_ops.
by auto; reflexivity ].
eapply f_equal2; [|trivial]. eapply f_equal.
expand_lists ().
- subst f g carry_mulmod; reflexivity.
+ subst carry_mulmod; reflexivity.
Qed.
Derive carrymod
@@ -536,11 +535,10 @@ Section mod_ops.
Qed.
Derive addmod
- SuchThat (forall (fg: list Z * list Z)
- (f := fst fg) (g := snd fg)
+ SuchThat (forall (f g : list Z)
(Hf : length f = n)
(Hg : length g = n),
- (eval weight n (addmod fg)) mod (s - Associational.eval c)
+ (eval weight n (addmod f g)) mod (s - Associational.eval c)
= (eval weight n f + eval weight n g) mod (s - Associational.eval c))
As eval_addmod.
Proof.
@@ -548,16 +546,15 @@ Section mod_ops.
rewrite <-eval_add by auto.
eapply f_equal2; [|trivial]. eapply f_equal.
expand_lists ().
- subst f g addmod; reflexivity.
+ subst addmod; reflexivity.
Qed.
Derive submod
SuchThat (forall (coef:Z)
- (fg: list Z * list Z)
- (f := fst fg) (g := snd fg)
+ (f g : list Z)
(Hf : length f = n)
(Hg : length g = n),
- (eval weight n (submod coef fg)) mod (s - Associational.eval c)
+ (eval weight n (submod coef f g)) mod (s - Associational.eval c)
= (eval weight n f - eval weight n g) mod (s - Associational.eval c))
As eval_submod.
Proof.
@@ -565,7 +562,7 @@ Section mod_ops.
rewrite <-eval_sub with (coef:=coef) by auto.
eapply f_equal2; [|trivial]. eapply f_equal.
expand_lists ().
- subst f g submod; reflexivity.
+ subst submod; reflexivity.
Qed.
Derive oppmod
@@ -1634,17 +1631,17 @@ Module Ring.
is_bounded_by tight_bounds arg = true
-> is_bounded_by loose_bounds (Interp_rrelaxv arg) = true
/\ Interp_rrelaxv arg = expanding_id n arg)
- (carry_mulmod : list Z * list Z -> list Z)
+ (carry_mulmod : list Z -> list Z -> list Z)
(Hcarry_mulmod
- : forall fg,
- length (fst fg) = n -> length (snd fg) = n ->
- (eval (carry_mulmod fg)) mod (s - Associational.eval c)
- = (eval (fst fg) * eval (snd fg)) mod (s - Associational.eval c))
+ : forall f g,
+ length f = n -> length g = n ->
+ (eval (carry_mulmod f g)) mod (s - Associational.eval c)
+ = (eval f * eval g) mod (s - Associational.eval c))
(Interp_rcarry_mulv : list Z * list Z -> list Z)
(HInterp_rcarry_mulv : forall arg,
is_bounded_by2 loose_bounds arg = true
-> is_bounded_by tight_bounds (Interp_rcarry_mulv arg) = true
- /\ Interp_rcarry_mulv arg = carry_mulmod arg)
+ /\ Interp_rcarry_mulv arg = carry_mulmod (fst arg) (snd arg))
(carrymod : list Z -> list Z)
(Hcarrymod
: forall f,
@@ -1656,28 +1653,28 @@ Module Ring.
is_bounded_by loose_bounds arg = true
-> is_bounded_by tight_bounds (Interp_rcarryv arg) = true
/\ Interp_rcarryv arg = carrymod arg)
- (addmod : list Z * list Z -> list Z)
+ (addmod : list Z -> list Z -> list Z)
(Haddmod
- : forall fg,
- length (fst fg) = n -> length (snd fg) = n ->
- (eval (addmod fg)) mod (s - Associational.eval c)
- = (eval (fst fg) + eval (snd fg)) mod (s - Associational.eval c))
+ : forall f g,
+ length f = n -> length g = n ->
+ (eval (addmod f g)) mod (s - Associational.eval c)
+ = (eval f + eval g) mod (s - Associational.eval c))
(Interp_raddv : list Z * list Z -> list Z)
(HInterp_raddv : forall arg,
is_bounded_by2 tight_bounds arg = true
-> is_bounded_by loose_bounds (Interp_raddv arg) = true
- /\ Interp_raddv arg = addmod arg)
- (submod : list Z * list Z -> list Z)
+ /\ Interp_raddv arg = addmod (fst arg) (snd arg))
+ (submod : list Z -> list Z -> list Z)
(Hsubmod
- : forall fg,
- length (fst fg) = n -> length (snd fg) = n ->
- (eval (submod fg)) mod (s - Associational.eval c)
- = (eval (fst fg) - eval (snd fg)) mod (s - Associational.eval c))
+ : forall f g,
+ length f = n -> length g = n ->
+ (eval (submod f g)) mod (s - Associational.eval c)
+ = (eval f - eval g) mod (s - Associational.eval c))
(Interp_rsubv : list Z * list Z -> list Z)
(HInterp_rsubv : forall arg,
is_bounded_by2 tight_bounds arg = true
-> is_bounded_by loose_bounds (Interp_rsubv arg) = true
- /\ Interp_rsubv arg = submod arg)
+ /\ Interp_rsubv arg = submod (fst arg) (snd arg))
(oppmod : list Z -> list Z)
(Hoppmod
: forall f,
@@ -1773,10 +1770,12 @@ Module Ring.
| [ |- _ = _ :> Z ] => first [ reflexivity | rewrite <- m_eq; reflexivity ]
| [ H : context[?x] |- Fdecode ?x = _ ] => rewrite H
| [ H : context[?x _] |- Fdecode (?x _) = _ ] => rewrite H
+ | [ H : context[?x _ _] |- Fdecode (?x _ _) = _ ] => rewrite H
| _ => progress cbv [Fdecode]
| [ |- _ = _ :> F _ ] => apply F.eq_to_Z_iff
| _ => progress autorewrite with push_FtoZ
| _ => rewrite m_eq
+ | [ H : context[?x _ _] |- context[eval (?x _ _)] ] => rewrite H
| [ H : context[?x _] |- context[eval (?x _)] ] => rewrite H
| [ H : context[?x] |- context[eval ?x] ] => rewrite H
| [ |- context[List.length ?x] ]
@@ -4187,347 +4186,69 @@ Module Compilers.
: @Compilers.Uncurried.expr.default.Expr R
:= expr.CallWithContinuation (@ident.untranslate _) (@ident.fst) (@ident.snd) e k.
- (** It's not clear how to "plug in the identity continuation"
- for the CPS'd form of an expression of type [((A -> B) -> C)
- -> D]. So we must describe types of at most second order
- functions, so that we can write a uniform "plug in the
- identity continuation" transformation. *)
- Module second_order.
- Import Compilers.type.
- Module Import type.
- Inductive flat_type :=
- | type_primitive (_ : primitive)
- | prod (_ : flat_type) (_ : flat_type)
- | list (_ : flat_type).
- Inductive argtype :=
- | flat_arg (_ : flat_type)
- | arrow_arg (s : flat_type) (d : argtype)
- | prod_arg (_ _ : argtype).
- Inductive type :=
- | flat (_ : flat_type)
- | arrow (s : argtype) (d : type).
- End type.
-
- Module Export Coercions.
- Coercion type_primitive : primitive >-> flat_type.
- Coercion flat_arg : flat_type >-> argtype.
- Coercion flat : flat_type >-> type.
- End Coercions.
- Notation flat_type := flat_type.
- Notation argtype := argtype.
- Notation type := type.
-
- Fixpoint flat_to_type (t : flat_type) : Compilers.type.type
- := match t with
- | type_primitive x => x
- | prod A B => Compilers.type.prod (flat_to_type A) (flat_to_type B)
- | list A => Compilers.type.list (flat_to_type A)
- end.
-
- Fixpoint arg_to_type (t : argtype) : Compilers.type.type
- := match t with
- | flat_arg t => flat_to_type t
- | arrow_arg s d => Compilers.type.arrow (flat_to_type s) (arg_to_type d)
- | prod_arg A B => Compilers.type.prod (arg_to_type A) (arg_to_type B)
- end.
-
- Fixpoint to_type (t : type) : Compilers.type.type
- := match t with
- | flat t => flat_to_type t
- | arrow s d => Compilers.type.arrow (arg_to_type s) (to_type d)
- end.
-
- Fixpoint flat_of_type (t : Compilers.type.type) : option flat_type
- := match t with
- | Compilers.type.type_primitive x => @Some flat_type x
- | Compilers.type.prod A B
- => match flat_of_type A, flat_of_type B with
- | Some A, Some B => @Some flat_type (prod A B)
- | _, _ => None
- end
- | Compilers.type.arrow s d => None
- | Compilers.type.list A
- => option_map list (flat_of_type A)
- end.
-
- Fixpoint arg_of_type (t : Compilers.type.type) : option argtype
- := match t with
- | Compilers.type.prod A B as t
- => match flat_of_type t, arg_of_type A, arg_of_type B with
- | Some t, _, _
- => @Some argtype t
- | None, Some A, Some B
- => @Some argtype (prod_arg A B)
- | _, _, _ => None
- end
- | Compilers.type.arrow s d
- => match flat_of_type s, arg_of_type d with
- | Some s, Some d => Some (arrow_arg s d)
- | _, _ => None
- end
- | Compilers.type.type_primitive _ as t
- | Compilers.type.list _ as t
- => option_map flat_arg (flat_of_type t)
- end.
-
- Fixpoint of_type (t : Compilers.type.type) : option type
- := match t with
- | Compilers.type.arrow s d
- => match arg_of_type s, of_type d with
- | Some s, Some d => Some (arrow s d)
- | _, _ => None
- end
- | Compilers.type.prod _ _ as t
- | Compilers.type.type_primitive _ as t
- | Compilers.type.list _ as t
- => option_map flat (flat_of_type t)
- end.
-
- Fixpoint try_transport_flat_of_type P (t : Compilers.type.type)
- : P t -> option { t' : _ & P (flat_to_type t') }
- := match t with
- | Compilers.type.type_primitive x
- => fun v => Some (existT _ (x : flat_type) v)
- | Compilers.type.prod A B
- => fun v
- => match try_transport_flat_of_type (fun a => P (a * _)%ctype) A v with
- | Some (existT A v)
- => match try_transport_flat_of_type (fun b => P (_ * b)%ctype) B v with
- | Some (existT B v)
- => Some (existT _ (prod A B) v)
- | None => None
- end
- | None => None
- end
- | Compilers.type.arrow s d => fun _ => None
- | Compilers.type.list A
- => fun v
- => option_map
- (fun '(existT A v) => existT (fun t => P (flat_to_type t)) (list A) v)
- (try_transport_flat_of_type (fun a => P (Compilers.type.list a)) A v)
- end.
-
- Fixpoint try_transport_arg_of_type P (t : Compilers.type.type)
- : P t -> option { t' : _ & P (arg_to_type t') }
- := match t with
- | Compilers.type.prod A B as t
- => fun v
- => match try_transport_flat_of_type P t v with
- | Some (existT t v) => Some (existT (fun t' => P (arg_to_type t')) t v)
- | None
- => match try_transport_arg_of_type (fun a => P (a * _)%ctype) A v with
- | Some (existT A v)
- => match try_transport_arg_of_type (fun b => P (_ * b)%ctype) B v with
- | Some (existT B v)
- => Some (existT _ (prod_arg A B) v)
- | None => None
- end
- | None => None
- end
- end
- | Compilers.type.arrow s d
- => fun v
- => match try_transport_flat_of_type (fun s => P (s -> _)%ctype) s v with
- | Some (existT s v)
- => match try_transport_flat_of_type (fun d => P (_ -> d)%ctype) d v with
- | Some (existT d v)
- => Some (existT (fun t' => P (arg_to_type t')) (arrow_arg s d) v)
- | None => None
- end
- | None => None
- end
- | Compilers.type.type_primitive _ as t
- | Compilers.type.list _ as t
- => fun v
- => option_map
- (fun '(existT t v) => existT (fun t => P (arg_to_type t)) (flat_arg t) v)
- (try_transport_flat_of_type P t v)
- end.
-
- Fixpoint try_transport_of_type P (t : Compilers.type.type)
- : P t -> option { t' : _ & P (to_type t') }
- := match t with
- | Compilers.type.arrow s d
- => fun v
- => match try_transport_arg_of_type (fun s => P (s -> _)%ctype) s v with
- | Some (existT s v)
- => match try_transport_of_type (fun d => P (_ -> d)%ctype) d v with
- | Some (existT d v)
- => Some (existT (fun t' => P (to_type t')) (arrow s d) v)
- | None => None
- end
- | None => None
- end
- | Compilers.type.prod _ _ as t
- | Compilers.type.type_primitive _ as t
- | Compilers.type.list _ as t
- => fun v
- => option_map
- (fun '(existT t v) => existT (fun t => P (to_type t)) (flat t) v)
- (try_transport_flat_of_type P t v)
- end.
- End second_order.
- Import second_order.Coercions.
-
- Fixpoint untranslate_translate_flat
- (P : Compilers.type.type -> Type)
- {R}
- {t : second_order.flat_type}
- (e : P (second_order.to_type t))
- {struct t}
- : P (type.untranslate R (type.translate (second_order.to_type t)))
- := match t return P (second_order.to_type t)
- -> P (type.untranslate R (type.translate (second_order.to_type t)))
- with
- | second_order.type.type_primitive x => id
- | second_order.type.prod A B
- => fun ab : P (second_order.flat_to_type A * second_order.flat_to_type B)%ctype
- => @untranslate_translate_flat
- (fun a => P (a * _)%ctype)
- R A
- (@untranslate_translate_flat
- (fun b => P (_ * b)%ctype)
- R B
- ab)
- | second_order.type.list A
- => @untranslate_translate_flat
- (fun t => P (Compilers.type.list t))
- R A
- end e.
-
- Fixpoint untranslate_translate_flat'
- (P : Compilers.type.type -> Type)
- {R}
- {t : second_order.flat_type}
- (e : P (type.untranslate R (type.translate (second_order.to_type t))))
- {struct t}
- : P (second_order.to_type t)
- := match t return P (type.untranslate R (type.translate (second_order.to_type t)))
- -> P (second_order.to_type t)
- with
- | second_order.type.type_primitive x => id
- | second_order.type.prod A B
- => fun ab :
- (* ignore this line *) P (type.untranslate R (type.translate (second_order.flat_to_type A)) * type.untranslate R (type.translate (second_order.flat_to_type B)))%ctype
- => @untranslate_translate_flat'
- (fun a => P (a * _)%ctype)
- R A
- (@untranslate_translate_flat'
- (fun b => P (_ * b)%ctype)
- R B
- ab)
- | second_order.type.list A
- => @untranslate_translate_flat'
- (fun t => P (Compilers.type.list t))
- R A
- end e.
-
- Definition transport_final_codomain_flat P {t}
- : P (second_order.flat_to_type t)
- -> P (type.final_codomain (second_order.flat_to_type t))
- := match t with
- | second_order.type.type_primitive x => id
- | second_order.type.prod x x0 => id
- | second_order.type.list x => id
- end.
-
- Definition transport_final_codomain_flat' P {t}
- : P (type.final_codomain (second_order.flat_to_type t))
- -> P (second_order.flat_to_type t)
- := match t with
- | second_order.type.type_primitive x => id
- | second_order.type.prod x x0 => id
- | second_order.type.list x => id
+ Local Notation iffT A B := ((A -> B) * (B -> A))%type.
+ (** We can only "plug in the identity continuation" for flat
+ (arrow-free) types. (Actually, we know how to do it in a
+ very ad-hoc way for types of at-most second-order functions;
+ see git history. This is much simpler.) *)
+ Fixpoint try_untranslate_translate {R} {t}
+ : option (forall (P : Compilers.type.type -> Type),
+ iffT (P (type.untranslate R (type.translate t))) (P t))
+ := match t return option (forall (P : Compilers.type.type -> Type),
+ iffT (P (type.untranslate R (type.translate t))) (P t)) with
+ | Compilers.type.type_primitive x
+ => Some (fun P => ((fun v => v), (fun v => v)))
+ | type.arrow s d => None
+ | Compilers.type.prod A B
+ => (fA <- (@try_untranslate_translate _ A);
+ fB <- (@try_untranslate_translate _ B);
+ Some
+ (fun P
+ => let fA := fA (fun A => P (Compilers.type.prod A (type.untranslate R (type.translate B)))) in
+ let fB := fB (fun B => P (Compilers.type.prod A B)) in
+ ((fun v => fst fB (fst fA v)),
+ (fun v => snd fA (snd fB v)))))%option
+ | Compilers.type.list A
+ => (fA <- (@try_untranslate_translate R A);
+ Some (fun P => fA (fun A => P (Compilers.type.list A))))%option
end.
- Fixpoint untranslate_translate_arg
- {var}
- {R}
- {t : second_order.argtype}
- (e : @Compilers.Uncurried.expr.default.expr var (second_order.arg_to_type t))
- {struct t}
- : @Compilers.Uncurried.expr.default.expr var (type.untranslate R (type.translate (second_order.arg_to_type t)))
- := match t return Compilers.Uncurried.expr.default.expr (second_order.arg_to_type t)
- -> Compilers.Uncurried.expr.default.expr (type.untranslate R (type.translate (second_order.arg_to_type t)))
- with
- | second_order.type.flat_arg t
- => untranslate_translate_flat _
- | second_order.type.arrow_arg s d
- => fun e'
- => Abs (fun v :
- (* ignore this line *) var (type.untranslate R (type.translate (second_order.flat_to_type s)) * (type.untranslate R (type.translate (second_order.arg_to_type d)) -> R))%ctype
- => (ident.snd @@ Var v)
- @ (@untranslate_translate_arg
- var R d
- (e' @ (untranslate_translate_flat' _ (ident.fst @@ Var v)))))%expr
- | second_order.type.prod_arg A B
- => fun e' : expr.default.expr (second_order.arg_to_type A * second_order.arg_to_type B)
- => ((Abs (fun a => Abs (fun b => (Var a, Var b))))
- @ (@untranslate_translate_arg var R A (ident.fst @@ e'))
- @ (@untranslate_translate_arg var R B (ident.snd @@ e')))%expr
- end e.
-
Local Notation "x <-- e1 ; e2" := (expr.splice e1 (fun x => e2%cpsexpr)) : cpsexpr_scope.
- Fixpoint call_fun_with_id_continuation'
- {var}
- {t : second_order.type}
- (R := type.final_codomain (second_order.to_type t))
- (e : @expr (fun t0 =>
- @Uncurried.expr.expr default.ident.ident var (type.untranslate R t0))
- (type.translate (second_order.to_type t)))
- {struct t}
- : @Compilers.Uncurried.expr.default.expr var (second_order.to_type t)
- := match t
- return (@expr (fun t0 =>
- @Uncurried.expr.expr default.ident.ident var (type.untranslate (type.final_codomain (second_order.to_type t)) t0))
- (type.translate (second_order.to_type t)))
- -> @Compilers.Uncurried.expr.default.expr var (second_order.to_type t)
- with
- | second_order.type.flat t
- => fun e'
- => transport_final_codomain_flat'
- _
- (@call_with_continuation
- var _ _ e'
- (fun e'' => transport_final_codomain_flat _ (untranslate_translate_flat' _ e'')))
- | second_order.type.arrow s d
- => fun e' :
- (* ignore this line *) expr (type.translate (second_order.arg_to_type s) * (type.translate (second_order.to_type d) --->) --->)
- => Abs (s:=second_order.arg_to_type s) (d:=second_order.to_type d)
- (fun v
- => @call_fun_with_id_continuation'
- var d
- (f <-- e';
+ Definition call_fun_with_id_continuation'
+ {s d}
+ : option (forall var
+ (e : @expr _ (type.translate (s -> d))),
+ @Compilers.Uncurried.expr.default.expr var (s -> d))
+ := (fs <- (@try_untranslate_translate _ s);
+ fd <- (@try_untranslate_translate _ d);
+ Some
+ (fun var e
+ => let P := @Compilers.Uncurried.expr.default.expr var in
+ Abs
+ (fun v : var s
+ => call_with_continuation
+ ((f <-- e;
k <- (λ r, expr.Halt r);
- p <- (untranslate_translate_arg (Var v), k);
+ p <- (snd (fs P) (Var v), k);
f @ p)%cpsexpr)
- end e.
- Definition CallFunWithIdContinuation'
- {t : second_order.type}
- (e : Expr (type.translate (second_order.to_type t)))
- : @Compilers.Uncurried.expr.default.Expr (second_order.to_type t)
- := fun var => @call_fun_with_id_continuation' _ t (e _).
+ (fst (fd P)))))%option.
- Definition CallFunWithIdContinuation
- {t}
- (e : Expr (type.translate t))
- := match second_order.try_transport_of_type (fun t => Expr (type.translate t)) _
- e
- as o return match o with None => _ | _ => _ end
- with
- | Some v => CallFunWithIdContinuation' (projT2 v)
- | None => I
- end.
+ Definition call_fun_with_id_continuation
+ {var}
+ {s d} (e : @expr _ (type.translate (s -> d)))
+ : option (@Compilers.Uncurried.expr.default.expr var (s -> d))
+ := option_map
+ (fun f => f _ e)
+ (@call_fun_with_id_continuation' s d).
- Definition CallFunWithIdContinuation_opt
- {t}
- (e : Expr (type.translate t))
- : option (@Compilers.Uncurried.expr.default.Expr t)
- := (e' <- (second_order.try_transport_of_type
- (fun t => Expr (type.translate t)) _
- e);
- type.try_transport _ _ _ (CallFunWithIdContinuation' (projT2 e')))%option.
+ Definition CallFunWithIdContinuation
+ {s d}
+ (e : Expr (type.translate (s -> d)))
+ : option (@Compilers.Uncurried.expr.default.Expr (s -> d))
+ := option_map
+ (fun f var => f _ (e _))
+ (@call_fun_with_id_continuation' s d).
End default.
Include default.
End CPS.
@@ -6404,19 +6125,20 @@ Module test3.
(z * z)) in
pose v as E.
vm_compute in E.
- pose (PartialEvaluate false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'.
+ pose (option_map (PartialEvaluate false) (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'.
vm_compute in E'.
lazymatch (eval cbv delta [E'] in E') with
- | (fun var : type -> Type =>
- (λ x : var (type.type_primitive type.Z),
- expr_let x0 := Var x * Var x in
- expr_let x1 := Var x0 * Var x0 in
- expr_let x2 := Var x1 * Var x1 in
- expr_let x3 := Var x2 * Var x2 in
- Var x3 * Var x3)%expr)
+ | (Some
+ (fun var : type -> Type =>
+ (λ x : var (type.type_primitive type.Z),
+ expr_let x0 := Var x * Var x in
+ expr_let x1 := Var x0 * Var x0 in
+ expr_let x2 := Var x1 * Var x1 in
+ expr_let x3 := Var x2 * Var x2 in
+ Var x3 * Var x3)%expr))
=> idtac
end.
- pose (PartialEvaluateWithBounds1 E' (Some r[0~>10]%zrange)) as E'''.
+ pose (PartialEvaluateWithBounds1 (invert_Some E') (Some r[0~>10]%zrange)) as E'''.
lazy in E'''.
lazymatch (eval cbv delta [E'''] in E''') with
| (fun var : type -> Type =>
@@ -6441,10 +6163,10 @@ Module test4.
(xz :: xz :: nil)) in
pose v as E.
vm_compute in E.
- pose (PartialEvaluate false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'.
+ pose (option_map (PartialEvaluate false) (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))) as E'.
lazy in E'.
clear E.
- pose (PartialEvaluateWithBounds1 E' (Some [Some r[0~>10]%zrange],Some [Some r[0~>10]%zrange])) as E''.
+ pose (PartialEvaluateWithBounds1 (invert_Some E') (Some [Some r[0~>10]%zrange],Some [Some r[0~>10]%zrange])) as E''.
lazy in E''.
lazymatch (eval cbv delta [E''] in E'') with
| (fun var : type -> Type =>
@@ -6468,7 +6190,7 @@ Module test5.
x) in
pose v as E.
vm_compute in E.
- pose (ReassociateSmallConstants.Reassociate (2^8) (PartialEvaluate false (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))))) as E'.
+ pose (ReassociateSmallConstants.Reassociate (2^8) (PartialEvaluate false (invert_Some (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E)))))) as E'.
lazy in E'.
clear E.
lazymatch (eval cbv delta [E'] in E') with
@@ -6495,7 +6217,7 @@ Module test6.
pose (CPS.CallFunWithIdContinuation (CPS.Translate (canonicalize_list_recursion E))) as E'.
lazy in E'.
clear E.
- pose (PartialEvaluate false E') as E''.
+ pose (PartialEvaluate false (invert_Some E')) as E''.
lazy in E''.
lazymatch eval cbv delta [E''] in E'' with
| fun var : type -> Type => (λ x : var (type.type_primitive type.Z), Var x)%expr
@@ -6650,7 +6372,7 @@ Create HintDb reify_gen_cache.
Derive carry_mul_gen
SuchThat (forall (limbwidth_num limbwidth_den : Z)
- (fg : list Z * list Z)
+ (f g : list Z)
(n : nat)
(s : Z)
(c : list (Z * Z))
@@ -6658,11 +6380,11 @@ Derive carry_mul_gen
(idxs : list nat)
(len_idxs : nat),
Interp (t:=type.reify_type_of carry_mulmod)
- carry_mul_gen limbwidth_num limbwidth_den s c n len_c idxs len_idxs fg
- = carry_mulmod limbwidth_num limbwidth_den s c n len_c idxs len_idxs fg)
+ carry_mul_gen limbwidth_num limbwidth_den s c n len_c idxs len_idxs f g
+ = carry_mulmod limbwidth_num limbwidth_den s c n len_c idxs len_idxs f g)
As carry_mul_gen_correct.
Proof. Time cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Time Qed.
-Hint Extern 1 (_ = carry_mulmod _ _ _ _ _ _ _ _ _) => simple apply carry_mul_gen_correct : reify_gen_cache.
+Hint Extern 1 (_ = carry_mulmod _ _ _ _ _ _ _ _ _ _) => simple apply carry_mul_gen_correct : reify_gen_cache.
Derive carry_gen
SuchThat (forall (limbwidth_num limbwidth_den : Z)
@@ -6696,14 +6418,14 @@ Hint Extern 1 (_ = encodemod _ _ _ _ _ _ _) => simple apply encode_gen_correct :
Derive add_gen
SuchThat (forall (limbwidth_num limbwidth_den : Z)
- (fg : list Z * list Z)
+ (f g : list Z)
(n : nat),
Interp (t:=type.reify_type_of addmod)
- add_gen limbwidth_num limbwidth_den n fg
- = addmod limbwidth_num limbwidth_den n fg)
+ add_gen limbwidth_num limbwidth_den n f g
+ = addmod limbwidth_num limbwidth_den n f g)
As add_gen_correct.
Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed.
-Hint Extern 1 (_ = addmod _ _ _ _) => simple apply add_gen_correct : reify_gen_cache.
+Hint Extern 1 (_ = addmod _ _ _ _ _) => simple apply add_gen_correct : reify_gen_cache.
Derive sub_gen
SuchThat (forall (limbwidth_num limbwidth_den : Z)
(n : nat)
@@ -6711,13 +6433,13 @@ Derive sub_gen
(c : list (Z * Z))
(len_c : nat)
(coef : Z)
- (fg : list Z * list Z),
+ (f g : list Z),
Interp (t:=type.reify_type_of submod)
- sub_gen limbwidth_num limbwidth_den s c n len_c coef fg
- = submod limbwidth_num limbwidth_den s c n len_c coef fg)
+ sub_gen limbwidth_num limbwidth_den s c n len_c coef f g
+ = submod limbwidth_num limbwidth_den s c n len_c coef f g)
As sub_gen_correct.
Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed.
-Hint Extern 1 (_ = submod _ _ _ _ _ _ _ _) => simple apply sub_gen_correct : reify_gen_cache.
+Hint Extern 1 (_ = submod _ _ _ _ _ _ _ _ _) => simple apply sub_gen_correct : reify_gen_cache.
Derive opp_gen
SuchThat (forall (limbwidth_num limbwidth_den : Z)
@@ -6773,6 +6495,7 @@ Derive id_gen
Proof. cache_reify (). exact admit. (* correctness of initial parts of the pipeline *) Qed.
Hint Extern 1 (_ = expanding_id _ _) => simple apply id_gen_correct : reify_gen_cache.
+Import Uncurry.
Module Pipeline.
Import GeneralizeVar.
Inductive ErrorMessage :=
@@ -6808,16 +6531,17 @@ Module Pipeline.
expr.Interp (@for_reification.ident.interp) E.
Admitted.
- Definition BoundsPipeline_0_or_1
+ Definition BoundsPipeline
(with_dead_code_elimination : bool := true)
(with_subst01 : bool)
+ relax_zrange
{t}
(E : Expr t)
- {t'}
- (CheckedPartialEvaluateWithBounds : _ -> _ -> _ + partial.data t')
+ arg_bounds
out_bounds
- : ErrorT (Expr t)
- := let E := CPS.CallFunWithIdContinuation_opt (CPS.Translate E) in
+ : ErrorT (Expr (type.uncurry t))
+ := let E := expr.Uncurry E in
+ let E := CPS.CallFunWithIdContinuation (CPS.Translate E) in
match E with
| Some E
=> (let E := PartialEvaluate false E in
@@ -6832,7 +6556,7 @@ Module Pipeline.
let E := FromFlat e in
let E := if with_subst01 then Subst01.Subst01 E else E in
let E := ReassociateSmallConstants.Reassociate (2^8) E in
- let E := CheckedPartialEvaluateWithBounds E out_bounds in
+ let E := CheckedPartialEvaluateWithBounds1 relax_zrange E arg_bounds out_bounds in
match E with
| inl E => Success E
| inr b
@@ -6841,45 +6565,14 @@ Module Pipeline.
| None => Error (Type_too_complicated_for_cps t)
end.
- Definition BoundsPipeline
- (with_dead_code_elimination : bool := true)
- (with_subst01 : bool)
- relax_zrange
- {s d}
- (E : Expr (s -> d))
- arg_bounds
- out_bounds
- : ErrorT (Expr (s -> d))
- := BoundsPipeline_0_or_1
- (*with_dead_code_elimination*)
- with_subst01
- E
- (fun E => CheckedPartialEvaluateWithBounds1 relax_zrange E arg_bounds)
- out_bounds.
-
- Definition BoundsPipelineConst
- (with_dead_code_elimination : bool := true)
- (with_subst01 : bool)
- relax_zrange
- {t}
- (E : Expr t)
- out_bounds
- : ErrorT (Expr t)
- := BoundsPipeline_0_or_1
- (*with_dead_code_elimination*)
- with_subst01
- E
- (fun E => CheckedPartialEvaluateWithBounds0 relax_zrange E)
- out_bounds.
-
Lemma BoundsPipeline_correct
(with_dead_code_elimination : bool := true)
(with_subst01 : bool)
relax_zrange
(Hrelax : forall r r' z : zrange,
(z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true)
- {s d}
- (e : Expr (s -> d))
+ {t}
+ (e : Expr t)
arg_bounds
out_bounds
rv
@@ -6887,9 +6580,9 @@ Module Pipeline.
: forall arg
(Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true),
ZRange.type.option.is_bounded_by out_bounds (Interp rv arg) = true
- /\ Interp rv arg = Interp e arg.
+ /\ Interp rv arg = app_curried (Interp e) arg.
Proof.
- cbv [BoundsPipeline BoundsPipeline_0_or_1 Let_In] in *;
+ cbv [BoundsPipeline Let_In] in *;
repeat match goal with
| [ H : match ?x with _ => _ end = Success _ |- _ ]
=> destruct x eqn:?; cbv beta iota in H; [ | congruence ];
@@ -6913,15 +6606,15 @@ Module Pipeline.
Qed.
Definition BoundsPipeline_correct_transT
- {s d}
+ {t}
arg_bounds
out_bounds
- (InterpE : type.interp s -> type.interp d)
- (rv : Expr (s -> d))
+ (InterpE : type.interp t)
+ (rv : Expr (type.uncurry t))
:= forall arg
(Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true),
ZRange.type.option.is_bounded_by out_bounds (Interp rv arg) = true
- /\ Interp rv arg = InterpE arg.
+ /\ Interp rv arg = app_curried InterpE arg.
Lemma BoundsPipeline_correct_trans
(with_dead_code_elimination : bool := true)
@@ -6930,14 +6623,14 @@ Module Pipeline.
(Hrelax
: forall r r' z : zrange,
(z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true)
- {s d}
- (e : Expr (s -> d))
+ {t}
+ (e : Expr t)
arg_bounds out_bounds
- (InterpE : type.interp s -> type.interp d)
+ (InterpE : type.interp t)
(InterpE_correct
: forall arg
(Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true),
- Interp e arg = InterpE arg)
+ app_curried (Interp e) arg = app_curried InterpE arg)
rv
(Hrv : BoundsPipeline (*with_dead_code_elimination*) with_subst01 relax_zrange e arg_bounds out_bounds = Success rv)
: BoundsPipeline_correct_transT arg_bounds out_bounds InterpE rv.
@@ -6950,17 +6643,17 @@ Module Pipeline.
(with_dead_code_elimination : bool := true)
(with_subst01 : bool)
relax_zrange
- {s d}
- (E : for_reification.Expr (s -> d))
+ {t}
+ (E : for_reification.Expr t)
arg_bounds
out_bounds
- : ErrorT (Expr (s -> d))
+ : ErrorT (Expr (type.uncurry t))
:= let E := PrePipeline E in
@BoundsPipeline
(*with_dead_code_elimination*)
with_subst01
relax_zrange
- s d E arg_bounds out_bounds.
+ t E arg_bounds out_bounds.
Lemma BoundsPipeline_full_correct
(with_dead_code_elimination : bool := true)
@@ -6968,8 +6661,8 @@ Module Pipeline.
relax_zrange
(Hrelax : forall r r' z : zrange,
(z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true)
- {s d}
- (E : for_reification.Expr (s -> d))
+ {t}
+ (E : for_reification.Expr t)
arg_bounds
out_bounds
rv
@@ -6977,111 +6670,12 @@ Module Pipeline.
: forall arg
(Harg : ZRange.type.option.is_bounded_by arg_bounds arg = true),
ZRange.type.option.is_bounded_by out_bounds (Interp rv arg) = true
- /\ Interp rv arg = for_reification.Interp E arg.
+ /\ Interp rv arg = app_curried (for_reification.Interp E) arg.
Proof.
cbv [BoundsPipeline_full] in *.
eapply BoundsPipeline_correct_trans; [ eassumption | | eassumption.. ].
intros; erewrite PrePipeline_correct; reflexivity.
Qed.
-
- Lemma BoundsPipelineConst_correct
- (with_dead_code_elimination : bool := true)
- (with_subst01 : bool)
- relax_zrange
- (Hrelax : forall r r' z : zrange,
- (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true)
- {d}
- (e : Expr d)
- bounds
- rv
- (Hrv : BoundsPipelineConst (*with_dead_code_elimination*) with_subst01 relax_zrange e bounds = Success rv)
- : ZRange.type.option.is_bounded_by bounds (Interp rv) = true
- /\ Interp rv = Interp e.
- Proof.
- cbv [BoundsPipelineConst BoundsPipeline_0_or_1 Let_In] in *;
- repeat match goal with
- | [ H : match ?x with _ => _ end = Success _ |- _ ]
- => destruct x eqn:?; cbv beta iota in H; [ | congruence ];
- let H' := fresh in
- inversion H as [H']; clear H; rename H' into H
- end.
- { intros;
- match goal with
- | [ H : _ = _ |- _ ]
- => eapply CheckedPartialEvaluateWithBounds0_Correct in H;
- [ destruct H as [H0 H1] | .. ]
- end;
- [
- | eassumption || (try reflexivity).. ].
- refine (let H' := admit (* interp correctness *) in
- conj _ (eq_trans H' _));
- clearbody H'.
- { rewrite H'; eassumption. }
- { rewrite H0.
- exact admit. (* interp correctness *) } }
- Qed.
-
- Definition BoundsPipelineConst_correct_transT
- {t}
- out_bounds
- (InterpE : type.interp t)
- (rv : Expr t)
- := ZRange.type.option.is_bounded_by out_bounds (Interp rv) = true
- /\ Interp rv = InterpE.
-
- Lemma BoundsPipelineConst_correct_trans
- (with_dead_code_elimination : bool := true)
- (with_subst01 : bool)
- relax_zrange
- (Hrelax
- : forall r r' z : zrange,
- (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true)
- {t}
- (e : Expr t)
- out_bounds
- (InterpE : type.interp t)
- (InterpE_correct : Interp e = InterpE)
- rv
- (Hrv : BoundsPipelineConst (*with_dead_code_elimination*) with_subst01 relax_zrange e out_bounds = Success rv)
- : BoundsPipelineConst_correct_transT out_bounds InterpE rv.
- Proof.
- rewrite <- InterpE_correct.
- eapply @BoundsPipelineConst_correct; eassumption.
- Qed.
-
- Definition BoundsPipelineConst_full
- (with_dead_code_elimination : bool := true)
- (with_subst01 : bool)
- relax_zrange
- {t}
- (E : for_reification.Expr t)
- out_bounds
- : ErrorT (Expr t)
- := let E := PrePipeline E in
- @BoundsPipelineConst
- (*with_dead_code_elimination*)
- with_subst01
- relax_zrange
- t E out_bounds.
-
- Lemma BoundsPipelineConst_full_correct
- (with_dead_code_elimination : bool := true)
- (with_subst01 : bool)
- relax_zrange
- (Hrelax : forall r r' z : zrange,
- (z <=? r)%zrange = true -> relax_zrange r = Some r' -> (z <=? r')%zrange = true)
- {t}
- (E : for_reification.Expr t)
- out_bounds
- rv
- (Hrv : BoundsPipelineConst_full (*with_dead_code_elimination*) with_subst01 relax_zrange E out_bounds = Success rv)
- : ZRange.type.option.is_bounded_by out_bounds (Interp rv) = true
- /\ Interp rv = for_reification.Interp E.
- Proof.
- cbv [BoundsPipelineConst_full] in *.
- eapply BoundsPipelineConst_correct_trans; [ eassumption | | eassumption.. ].
- intros; erewrite PrePipeline_correct; reflexivity.
- Qed.
End Pipeline.
Definition round_up_bitwidth_gen (possible_values : list Z) (bitwidth : Z) : option Z
@@ -7176,19 +6770,13 @@ Section rcarry_mul.
relax_zrange
rop%Expr in_bounds out_bounds).
- Notation BoundsPipelineConst rop out_bounds
- := (Pipeline.BoundsPipelineConst
- (*false*) true
- relax_zrange
- rop%Expr out_bounds).
-
Notation BoundsPipeline_correct in_bounds out_bounds op
- := (fun rv (rop : Expr (type.reify_type_of op%function)) Hrop
+ := (fun rv (rop : Expr (type.reify_type_of op)) Hrop
=> @Pipeline.BoundsPipeline_correct_trans
(*false*) true
relax_zrange
(relax_zrange_gen_good _)
- _ _
+ _
rop
in_bounds
out_bounds
@@ -7196,19 +6784,6 @@ Section rcarry_mul.
Hrop rv)
(only parsing).
- Notation BoundsPipelineConst_correct out_bounds op
- := (fun rv (rop : Expr (type.reify_type_of op)) Hrop
- => @Pipeline.BoundsPipelineConst_correct_trans
- (*false*) true
- relax_zrange
- (relax_zrange_gen_good _)
- _
- rop%Expr
- out_bounds
- op
- Hrop rv)
- (only parsing).
-
(* N.B. We only need [rcarry_mul] if we want to extract the Pipeline; otherwise we can just use [rcarry_mul_correct] *)
Definition rcarry_mul
:= BoundsPipeline
@@ -7260,12 +6835,14 @@ Section rcarry_mul.
(encodemod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c)).
Definition rzero_correct
- := BoundsPipelineConst_correct
+ := BoundsPipeline_correct
+ tt
(Some tight_bounds)
(zeromod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c)).
Definition rone_correct
- := BoundsPipelineConst_correct
+ := BoundsPipeline_correct
+ tt
(Some tight_bounds)
(onemod (Qnum limbwidth) (Z.pos (Qden limbwidth)) s c n (List.length c)).
@@ -7363,8 +6940,8 @@ Section rcarry_mul.
(Interp raddv)
(Interp rsubv)
(Interp roppv)
- (Interp rzerov)
- (Interp ronev)
+ (Interp rzerov tt)
+ (Interp ronev tt)
(Interp rencodev).
Theorem Good : GoodT.
@@ -7382,9 +6959,12 @@ Section rcarry_mul.
| apply conj
| progress intros
| progress cbv [onemod zeromod]
+ | eapply Hrzerov (* to handle diff with whether or not correctness asks for boundedness of tt *)
+ | eapply Hronev (* to handle diff with whether or not correctness asks for boundedness of tt *)
| match goal with
| [ |- ?x = ?x ] => reflexivity
| [ |- ?x = ?ev ] => is_evar ev; reflexivity
+ | [ |- ZRange.type.option.is_bounded_by tt tt = true ] => reflexivity
end ].
Qed.
End make_ring.
@@ -7400,6 +6980,9 @@ Proof. cbv [pointwise_relation]; intros; subst; trivial. Qed.
Ltac peel_interp_app _ :=
lazymatch goal with
+ | [ |- ?R' (?InterpE ?arg) (?f ?arg) ]
+ => apply fg_equal_rel; [ | reflexivity ];
+ try peel_interp_app ()
| [ |- ?R' (Interp ?ev) (?f ?x) ]
=> let sv := type of x in
let fx := constr:(f x) in
@@ -7420,10 +7003,9 @@ Ltac peel_interp_app _ :=
end ] ]
end.
Ltac pre_cache_reify _ :=
+ cbv [app_curried];
let arg := fresh "arg" in
- (tryif intros arg _
- then apply fg_equal_rel; [ | reflexivity ]
- else hnf);
+ intros arg _;
peel_interp_app ();
[ lazymatch goal with
| [ |- ?R (Interp ?ev) _ ]
@@ -7733,7 +7315,6 @@ Module X25519_64.
base_51_encode_correct.
Print Assumptions base_51_good.
-
Import PrintingNotations.
Print base_51_carry_mul.
(*base_51_carry_mul =
@@ -8304,12 +7885,12 @@ Module MontgomeryReduction.
else res.
Notation BoundsPipeline_correct in_bounds out_bounds op
- := (fun rv (rop : Expr (type.reify_type_of op%function)) Hrop
+ := (fun rv (rop : Expr (type.reify_type_of op)) Hrop
=> @Pipeline.BoundsPipeline_correct_trans
false (* subst01 *)
relax_zrange
(relax_zrange_gen_good _)
- _ _
+ _
rop
in_bounds
out_bounds