diff options
-rw-r--r-- | pretyping/tacred.ml | 129 | ||||
-rw-r--r-- | pretyping/tacred.mli | 13 | ||||
-rw-r--r-- | test-suite/success/simpl_tuning.v | 149 |
3 files changed, 280 insertions, 11 deletions
diff --git a/pretyping/tacred.ml b/pretyping/tacred.ml index b6c67555d..4cf1e2345 100644 --- a/pretyping/tacred.ml +++ b/pretyping/tacred.ml @@ -16,6 +16,7 @@ open Termops open Namegen open Declarations open Inductive +open Libobject open Environ open Closure open Reductionops @@ -515,27 +516,127 @@ let special_red_case env sigma whfun (ci, p, c, lf) = in redrec (c, empty_stack) +(* data structure to hold the map kn -> rec_args for simpl *) + +type behaviour = { + b_nargs: int; + b_recargs: int list; + b_dont_expose_case: bool; +} + +let behaviour_table = ref (Refmap.empty : behaviour Refmap.t) + +let _ = + Summary.declare_summary "simplbehaviour" + { Summary.freeze_function = (fun () -> !behaviour_table); + Summary.unfreeze_function = (fun x -> behaviour_table := x); + Summary.init_function = (fun () -> behaviour_table := Refmap.empty) } + +type simpl_flag = [ `SimplDontExposeCase | `SimplNeverUnfold ] +type req = + | ReqLocal + | ReqGlobal of global_reference * (int list * int * simpl_flag list) + +let load_simpl_behaviour _ (_,(_,(r, b))) = + behaviour_table := Refmap.add r b !behaviour_table + +let cache_simpl_behaviour o = load_simpl_behaviour 1 o + +let classify_simpl_behaviour = function + | ReqLocal, _ -> Dispose + | ReqGlobal _, _ as o -> Substitute o + +let subst_simpl_behaviour (subst, (_, (r,o as orig))) = + ReqLocal, + let r' = fst (subst_global subst r) in if r==r' then orig else (r',o) + +let discharge_simpl_behaviour = function + | _,(ReqGlobal (ConstRef c, req), (_, b)) -> + let c' = pop_con c in + let vars = Lib.section_segment_of_constant c in + let extra = List.length vars in + let nargs' = if b.b_nargs < 0 then b.b_nargs else b.b_nargs + extra in + let recargs' = List.map ((+) extra) b.b_recargs in + let b' = { b with b_nargs = nargs'; b_recargs = recargs' } in + Some (ReqGlobal (ConstRef c', req), (ConstRef c', b')) + | _ -> None + +let rebuild_simpl_behaviour = function + | req, (ConstRef c, _ as x) -> req, x + | _ -> assert false + +let inSimplBehaviour = declare_object { (default_object "SIMPLBEHAVIOUR") with + load_function = load_simpl_behaviour; + cache_function = cache_simpl_behaviour; + classify_function = classify_simpl_behaviour; + subst_function = subst_simpl_behaviour; + discharge_function = discharge_simpl_behaviour; + rebuild_function = rebuild_simpl_behaviour; +} + +let set_simpl_behaviour local r (recargs, nargs, flags as req) = + let nargs = if List.mem `SimplNeverUnfold flags then max_int else nargs in + let nargs = List.fold_left max nargs recargs in + let behaviour = { + b_nargs = nargs; b_recargs = recargs; + b_dont_expose_case = List.mem `SimplDontExposeCase flags } in + let req = if local then ReqLocal else ReqGlobal (r, req) in + Lib.add_anonymous_leaf (inSimplBehaviour (req, (r, behaviour))) +;; + +let get_simpl_behaviour r = + try + let b = Refmap.find r !behaviour_table in + let flags = + if b.b_nargs = max_int then [`SimplNeverUnfold] + else if b.b_dont_expose_case then [`SimplDontExposeCase] else [] in + Some (b.b_recargs, (if b.b_nargs = max_int then -1 else b.b_nargs), flags) + with Not_found -> None + +let get_behaviour = function + | EvalVar _ | EvalRel _ | EvalEvar _ -> raise Not_found + | EvalConst c -> Refmap.find (ConstRef c) !behaviour_table + +let recargs r = + try let b = get_behaviour r in Some (b.b_recargs, b.b_nargs) + with Not_found -> None + +let dont_expose_case r = + try (get_behaviour r).b_dont_expose_case with Not_found -> false + (* [red_elim_const] contracts iota/fix/cofix redexes hidden behind constants by keeping the name of the constants in the recursive calls; it fails if no redex is around *) let rec red_elim_const env sigma ref largs = - match reference_eval sigma env ref with - | EliminationCases n when stack_args_size largs >= n -> + let nargs = stack_args_size largs in + let largs, unfold_anyway = + match recargs ref with + | None -> largs, false + | Some (_,n) when nargs < n -> raise Redelimination + | Some (l,n) -> + List.fold_left (fun stack i -> + let arg = stack_nth stack i in + let rarg = whd_construct_state env sigma (arg, empty_stack) in + match kind_of_term (fst rarg) with + | Construct _ -> stack_assign stack i (app_stack rarg) + | _ -> raise Redelimination) + largs l, n >= 0 && l = [] && nargs >= n in + try match reference_eval sigma env ref with + | EliminationCases n when nargs >= n -> let c = reference_value sigma env ref in let c', lrest = whd_betadelta_state env sigma (c,largs) in let whfun = whd_simpl_state env sigma in (special_red_case env sigma whfun (destCase c'), lrest) - | EliminationFix (min,minfxargs,infos) when stack_args_size largs >=min -> + | EliminationFix (min,minfxargs,infos) when nargs >= min -> let c = reference_value sigma env ref in let d, lrest = whd_betadelta_state env sigma (c,largs) in let f = make_elim_fun ([|Some (minfxargs,ref)|],infos) largs in let whfun = whd_construct_state env sigma in (match reduce_fix_use_function env sigma f whfun (destFix d) lrest with | NotReducible -> raise Redelimination - | Reduced (c,rest) -> (nf_beta sigma c, rest)) - | EliminationMutualFix (min,refgoal,refinfos) - when stack_args_size largs >= min -> + | Reduced (c,rest) -> (nf_beta sigma c, rest)) + | EliminationMutualFix (min,refgoal,refinfos) when nargs >= min -> let rec descend ref args = let c = reference_value sigma env ref in if ref = refgoal then @@ -551,6 +652,9 @@ let rec red_elim_const env sigma ref largs = | NotReducible -> raise Redelimination | Reduced (c,rest) -> (nf_beta sigma c, rest)) | _ -> raise Redelimination + with Redelimination when unfold_anyway -> + let c = reference_value sigma env ref in + whd_betaiotazeta sigma (app_stack (c, largs)), empty_stack (* reduce to whd normal form or to an applied constant that does not hide a reducible iota/fix/cofix redex (the "simpl" tactic) *) @@ -578,7 +682,14 @@ and whd_simpl_state env sigma s = | _ when isEvalRef env x -> let ref = destEvalRef x in (try - redrec (red_elim_const env sigma ref stack) + let hd, _ as s' = redrec (red_elim_const env sigma ref stack) in + let rec is_case x = match kind_of_term x with + | Lambda (_,_, x) | LetIn (_,_,_, x) | Cast (x, _,_) -> is_case x + | App (hd, _) -> is_case hd + | Case _ -> true + | _ -> false in + if dont_expose_case ref && is_case hd then raise Redelimination + else s' with Redelimination -> s) | _ -> s @@ -919,7 +1030,7 @@ let one_step_reduce env sigma c = | _ when isEvalRef env x -> let ref = destEvalRef x in (try - red_elim_const env sigma ref stack + red_elim_const env sigma ref stack with Redelimination -> match reference_opt_value sigma env ref with | Some d -> d, stack @@ -959,7 +1070,7 @@ let reduce_to_ref_gen allow_product env sigma ref t = with Not_found -> try let t' = nf_betaiota sigma (one_step_reduce env sigma t) in - elimrec env t' l + elimrec env t' l with NotStepReducible -> errorlabstrm "" (str "Cannot recognize a statement based on " ++ diff --git a/pretyping/tacred.mli b/pretyping/tacred.mli index d713eae42..f0f5f66d3 100644 --- a/pretyping/tacred.mli +++ b/pretyping/tacred.mli @@ -15,6 +15,7 @@ open Closure open Glob_term open Termops open Pattern +open Libnames type reduction_tactic_error = InvalidAbstraction of env * constr * (env * Type_errors.type_error) @@ -43,6 +44,14 @@ val red_product : reduction_function (** Red (raise Redelimination if nothing reducible) *) val try_red_product : reduction_function +(** Tune the behaviour of simpl for the given constant name *) +type simpl_flag = [ `SimplDontExposeCase | `SimplNeverUnfold ] + +val set_simpl_behaviour : + bool -> global_reference -> (int list * int * simpl_flag list) -> unit +val get_simpl_behaviour : + global_reference -> (int list * int * simpl_flag list) option + (** Simpl *) val simpl : reduction_function @@ -85,10 +94,10 @@ val reduce_to_quantified_ind : env -> evar_map -> types -> inductive * types (** [reduce_to_quantified_ref env sigma ref t] try to put [t] in the form [t'=(x1:A1)..(xn:An)(ref args)] and fails with user error if not possible *) val reduce_to_quantified_ref : - env -> evar_map -> Libnames.global_reference -> types -> types + env -> evar_map -> global_reference -> types -> types val reduce_to_atomic_ref : - env -> evar_map -> Libnames.global_reference -> types -> types + env -> evar_map -> global_reference -> types -> types val contextually : bool -> occurrences * constr_pattern -> (patvar_map -> reduction_function) -> reduction_function diff --git a/test-suite/success/simpl_tuning.v b/test-suite/success/simpl_tuning.v new file mode 100644 index 000000000..d4191b939 --- /dev/null +++ b/test-suite/success/simpl_tuning.v @@ -0,0 +1,149 @@ +(* as it is dynamically inferred by simpl *) +Arguments minus !n / m. + +Lemma foo x y : S (S x) - S y = 0. +simpl. +match goal with |- (match y with O => S x | S _ => _ end = 0) => idtac end. +Abort. + +(* we avoid exposing a match *) +Arguments minus n m : simpl nomatch. + +Lemma foo x : minus 0 x = 0. +simpl. +match goal with |- (0 = 0) => idtac end. +Abort. + +Lemma foo x y : S (S x) - S y = 0. +simpl. +match goal with |- (S x - y = 0) => idtac end. +Abort. + +Lemma foo x y : S (S x) - (S (match y with O => O | S z => S z end)) = 0. +simpl. +match goal with |-(S x - (match y with O => _ | S _ => _ end) = 0) => idtac end. +Abort. + +(* we unfold as soon as we have 1 args, but we avoid exposing a match *) +Arguments minus n / m : simpl nomatch. + +Lemma foo : minus 0 = fun x => 0. +simpl. +match goal with |- minus 0 = _ => idtac end. +Abort. +(* This does not work as one may expect. The point is that simpl is implemented + as "strong (whd_simpl_state)" and after unfolding minus you have + (fun m => match 0 => 0 | S n => ...) that is already in whd and exposes + a match, that of course "strong" would reduce away but at that stage + we don't know, and reducing by hand under the lambda is against whd *) + +(* extra tuning for the usual heuristic *) +Arguments minus !n / m : simpl nomatch. + +Lemma foo x y : S (S x) - S y = 0. +simpl. +match goal with |- (S x - y = 0) => idtac end. +Abort. + +Lemma foo x y : S (S x) - (S (match y with O => O | S z => S z end)) = 0. +simpl. +match goal with |-(S x - (match y with O => _ | S _ => _ end) = 0) => idtac end. +Abort. + +(* full control *) +Arguments minus !n !m /. + +Lemma foo x y : S (S x) - S y = 0. +simpl. +match goal with |- (S x - y = 0) => idtac end. +Abort. + +Lemma foo x y : S (S x) - (S (match y with O => O | S z => S z end)) = 0. +simpl. +match goal with |-(S x - (match y with O => _ | S _ => _ end) = 0) => idtac end. +Abort. + +(* omitting /, that being immediately after the last ! is irrelevant *) +Arguments minus !n !m. + +Lemma foo x y : S (S x) - S y = 0. +simpl. +match goal with |- (S x - y = 0) => idtac end. +Abort. + +Lemma foo x y : S (S x) - (S (match y with O => O | S z => S z end)) = 0. +simpl. +match goal with |-(S x - (match y with O => _ | S _ => _ end) = 0) => idtac end. +Abort. + +Definition pf (D1 C1 : Type) (f : D1 -> C1) (D2 C2 : Type) (g : D2 -> C2) := + fun x => (f (fst x), g (snd x)). + +Delimit Scope foo_scope with F. +Notation "@@" := nat (only parsing) : foo_scope. +Notation "@@" := (fun x => x) (only parsing). + +Arguments pf {D1%F C1%type} f [D2 C2] g x : simpl never. + +Lemma foo x : @pf @@ nat @@ nat nat @@ x = pf @@ @@ x. +Abort. + +Definition fcomp A B C f (g : A -> B) (x : A) : C := f (g x). + +(* fcomp is unfolded if applied to 6 args *) +Arguments fcomp {A B C}%type f g x /. + +Notation "f \o g" := (fcomp f g) (at level 50). + +Lemma foo (f g h : nat -> nat) x : pf (f \o g) h x = pf f h (g (fst x), snd x). +simpl. +match goal with |- (pf (f \o g) h x = _) => idtac end. +case x; intros x1 x2. +simpl. +match goal with |- (pf (f \o g) h _ = pf f h _) => idtac end. +unfold pf; simpl. +match goal with |- (f (g x1), h x2) = (f (g x1), h x2) => idtac end. +Abort. + +Definition volatile := fun x : nat => x. +Arguments volatile /. + +Lemma foo : volatile = volatile. +simpl. +match goal with |- (fun _ => _) = _ => idtac end. +Abort. + +Set Implicit Arguments. + +Section S1. + +Variable T1 : Type. + +Section S2. + +Variable T2 : Type. + +Fixpoint f (x : T1) (y : T2) n (v : unit) m {struct n} : nat := + match n, m with + | 0,_ => 0 + | S _, 0 => n + | S n', S m' => f x y n' v m' end. + +Global Arguments f x y !n !v !m. + +Lemma foo x y n m : f x y (S n) tt m = f x y (S n) tt (S m). +simpl. +match goal with |- (f _ _ _ _ _ = f _ _ _ _ _) => idtac end. +Abort. + +End S2. + +Lemma foo T x y n m : @f T x y (S n) tt m = @f T x y (S n) tt (S m). +simpl. +match goal with |- (f _ _ _ _ _ = f _ _ _ _ _) => idtac end. +Abort. + +End S1. + +Arguments f : clear implicits and scopes. + |