aboutsummaryrefslogtreecommitdiffhomepage
path: root/tactics/auto.ml
diff options
context:
space:
mode:
authorGravatar Matthieu Sozeau <matthieu.sozeau@inria.fr>2014-09-15 21:33:48 +0200
committerGravatar Matthieu Sozeau <matthieu.sozeau@inria.fr>2014-09-15 21:37:31 +0200
commit62a552b508b747b6cdf4bd818233f001ae4ce555 (patch)
tree80feb19c8d02935b550c7f6c971ea42fc39020b2 /tactics/auto.ml
parent1c113305039857ca219f252f5e80f4b179a39082 (diff)
Add a "Hint Mode ref (+ | -)*" hint for setting a global mode
of resulution for goals whose head is "ref". + means the argument is an input and shouldn't contain an evar, otherwise resolution fails. This generalizes the Typeclasses Strict Resolution option which prevents resolution to fire on underconstrained typeclass constraints, now the criterion can be applied to specific parameters. Also cleanup auto/eauto code, uncovering a potential backwards compatibility issue: in cases the goal contains existentials, we never use the discrimination net in auto/eauto. We should try to set this on once the contribs are stabilized (the stdlib goes through when the dnet is used in these cases).
Diffstat (limited to 'tactics/auto.ml')
-rw-r--r--tactics/auto.ml173
1 files changed, 115 insertions, 58 deletions
diff --git a/tactics/auto.ml b/tactics/auto.ml
index e386728fe..0f1d7cb02 100644
--- a/tactics/auto.ml
+++ b/tactics/auto.ml
@@ -137,9 +137,10 @@ module Bounded_net = Btermdn.Make(struct
let compare = pri_order_int
end)
-type search_entry = stored_data list * stored_data list * Bounded_net.t
+type search_entry = stored_data list * stored_data list * Bounded_net.t * bool array list
-let empty_se = ([],[],Bounded_net.create ())
+
+let empty_se = ([],[],Bounded_net.create (),[])
let eq_pri_auto_tactic (_, x) (_, y) =
if Int.equal x.pri y.pri && Option.equal constr_pattern_eq x.pat y.pat then
@@ -157,20 +158,25 @@ let eq_pri_auto_tactic (_, x) (_, y) =
else
false
-let add_tac pat t st (l,l',dn) =
+let add_tac pat t st (l,l',dn,m) =
match pat with
- | None -> if not (List.exists (eq_pri_auto_tactic t) l) then (insert t l, l', dn) else (l, l', dn)
+ | None ->
+ if not (List.exists (eq_pri_auto_tactic t) l) then (insert t l, l', dn, m)
+ else (l, l', dn, m)
| Some pat ->
- if not (List.exists (eq_pri_auto_tactic t) l')
- then (l, insert t l', Bounded_net.add st dn (pat,t)) else (l, l', dn)
-
-let rebuild_dn st ((l,l',dn) : search_entry) =
- (l, l', List.fold_left (fun dn (id, t) -> Bounded_net.add (Some st) dn (Option.get t.pat, (id, t)))
- (Bounded_net.create ()) l')
-
+ if not (List.exists (eq_pri_auto_tactic t) l')
+ then (l, insert t l', Bounded_net.add st dn (pat,t), m) else (l, l', dn, m)
+
+let rebuild_dn st ((l,l',dn,m) : search_entry) =
+ let dn' =
+ List.fold_left
+ (fun dn (id, t) -> Bounded_net.add (Some st) dn (Option.get t.pat, (id, t)))
+ (Bounded_net.create ()) l'
+ in
+ (l, l', dn', m)
-let lookup_tacs (hdc,c) st (l,l',dn) =
- let l' = Bounded_net.lookup st dn c in
+let lookup_tacs concl st (l,l',dn) =
+ let l' = Bounded_net.lookup st dn concl in
let sl' = List.stable_sort pri_order_int l' in
List.merge pri_order_int l sl'
@@ -378,18 +384,43 @@ module Hint_db = struct
let realize_tac (id,tac) = tac
+ let matches_mode args mode =
+ Array.length args == Array.length mode &&
+ Array.for_all2 (fun arg m -> not (m && occur_existential arg)) args mode
+
+ let matches_modes args modes =
+ if List.is_empty modes then true
+ else List.exists (matches_mode args) modes
+
let map_none db =
List.map realize_tac (Sort.merge pri_order (List.map snd db.hintdb_nopat) [])
let map_all k db =
- let (l,l',_) = find k db in
+ let (l,l',_,_) = find k db in
List.map realize_tac (Sort.merge pri_order (List.map snd db.hintdb_nopat @ l) l')
-
- let map_auto (k,c) db =
- let st = if db.use_dn then Some db.hintdb_state else None in
- let l' = lookup_tacs (k,c) st (find k db) in
+
+ (** Precondition: concl has no existentials *)
+ let map_auto (k,args) concl db =
+ let (l,l',dn,m) = find k db in
+ let st = if db.use_dn then (Some db.hintdb_state) else None in
+ let l' = lookup_tacs concl st (l,l',dn) in
List.map realize_tac (Sort.merge pri_order (List.map snd db.hintdb_nopat) l')
+ let map_existential (k,args) concl db =
+ let (l,l',_,m) = find k db in
+ if matches_modes args m then
+ List.map realize_tac (Sort.merge pri_order (List.map snd db.hintdb_nopat @ l) l')
+ else List.map realize_tac (List.map snd db.hintdb_nopat)
+
+ (* [c] contains an existential *)
+ let map_eauto (k,args) concl db =
+ let (l,l',dn,m) = find k db in
+ if matches_modes args m then
+ let st = if db.use_dn then Some db.hintdb_state else None in
+ let l' = lookup_tacs concl st (l,l',dn) in
+ List.map realize_tac (Sort.merge pri_order (List.map snd db.hintdb_nopat) l')
+ else List.map realize_tac (List.map snd db.hintdb_nopat)
+
let is_exact = function
| Give_exact _ -> true
| _ -> false
@@ -446,10 +477,10 @@ module Hint_db = struct
let add_list l db = List.fold_left (fun db k -> add_one k db) db l
let remove_sdl p sdl = List.smartfilter p sdl
- let remove_he st p (sl1, sl2, dn as he) =
+ let remove_he st p (sl1, sl2, dn, m as he) =
let sl1' = remove_sdl p sl1 and sl2' = remove_sdl p sl2 in
if sl1' == sl1 && sl2' == sl2 then he
- else rebuild_dn st (sl1', sl2', dn)
+ else rebuild_dn st (sl1', sl2', dn, m)
let remove_list grs db =
let filter (_, h) =
@@ -461,12 +492,12 @@ module Hint_db = struct
let remove_one gr db = remove_list [gr] db
let iter f db =
- f None (List.map (fun x -> realize_tac (snd x)) db.hintdb_nopat);
- Constr_map.iter (fun k (l,l',_) -> f (Some k) (List.map realize_tac (l@l'))) db.hintdb_map
+ f None [] (List.map (fun x -> realize_tac (snd x)) db.hintdb_nopat);
+ Constr_map.iter (fun k (l,l',_,m) -> f (Some k) m (List.map realize_tac (l@l'))) db.hintdb_map
let fold f db accu =
- let accu = f None (List.map (fun x -> snd (snd x)) db.hintdb_nopat) accu in
- Constr_map.fold (fun k (l,l',_) -> f (Some k) (List.map snd (l@l'))) db.hintdb_map accu
+ let accu = f None [] (List.map (fun x -> snd (snd x)) db.hintdb_nopat) accu in
+ Constr_map.fold (fun k (l,l',_,m) -> f (Some k) m (List.map snd (l@l'))) db.hintdb_map accu
let transparent_state db = db.hintdb_state
@@ -477,6 +508,10 @@ module Hint_db = struct
let add_cut path db =
{ db with hintdb_cut = normalize_path (PathOr (db.hintdb_cut, path)) }
+ let add_mode gr m db =
+ let (l,l',dn,ms) = find gr db in
+ { db with hintdb_map = Constr_map.add gr (l,l',dn,m :: ms) db.hintdb_map }
+
let cut db = db.hintdb_cut
let unfolds db = db.hintdb_unfolds
@@ -648,6 +683,17 @@ let make_extern pri pat tacast =
name = PathAny;
code = Extern tacast })
+let make_mode ref m =
+ let ty = Global.type_of_global_unsafe ref in
+ let ctx, t = decompose_prod ty in
+ let n = List.length ctx in
+ let m' = Array.of_list m in
+ if not (n == Array.length m') then
+ errorlabstrm "Hint"
+ (pr_global ref ++ str" has " ++ int n ++
+ str" arguments while the mode declares " ++ int (Array.length m'))
+ else m'
+
let make_trivial env sigma poly ?(name=PathAny) r =
let c,ctx = fresh_global_or_constr env sigma poly r in
let t = hnf_constr env sigma (type_of env sigma c) in
@@ -699,12 +745,18 @@ type hint_action =
| AddHints of hint_entry list
| RemoveHints of global_reference list
| AddCut of hints_path
+ | AddMode of global_reference * bool array
let add_cut dbname path =
let db = get_db dbname in
let db' = Hint_db.add_cut path db in
searchtable_add (dbname, db')
+let add_mode dbname l m =
+ let db = get_db dbname in
+ let db' = Hint_db.add_mode l m db in
+ searchtable_add (dbname, db')
+
type hint_obj = bool * string * hint_action (* locality, name, action *)
let cache_autohint (_,(local,name,hints)) =
@@ -714,19 +766,10 @@ let cache_autohint (_,(local,name,hints)) =
| AddHints hints -> add_hint name hints
| RemoveHints grs -> remove_hint name grs
| AddCut path -> add_cut name path
+ | AddMode (l, m) -> add_mode name l m
let (forward_subst_tactic, extern_subst_tactic) = Hook.make ()
- (* let subst_mps_or_ref subst cr = *)
- (* match cr with *)
- (* | IsConstr c -> let c' = subst_mps subst c in *)
- (* if c' == c then cr *)
- (* else IsConstr c' *)
- (* | IsGlobal r -> let r' = subst_global_reference subst r in *)
- (* if r' == r then cr *)
- (* else IsGlobal r' *)
- (* in *)
-
let subst_autohint (subst,(local,name,hintlist as obj)) =
let subst_key gr =
let (lab'', elab') = subst_global subst gr in
@@ -779,11 +822,14 @@ let subst_autohint (subst,(local,name,hintlist as obj)) =
if hintlist' == hintlist then obj else
(local,name,AddHints hintlist')
| RemoveHints grs ->
- let grs' = List.smartmap (fun x -> fst (subst_global subst x)) grs in
+ let grs' = List.smartmap (subst_global_reference subst) grs in
if grs==grs' then obj else (local, name, RemoveHints grs')
| AddCut path ->
let path' = subst_hints_path subst path in
if path' == path then obj else (local, name, AddCut path')
+ | AddMode (l,m) ->
+ let l' = subst_global_reference subst l in
+ (local, name, AddMode (l', m))
let classify_autohint ((local,name,hintlist) as obj) =
match hintlist with
@@ -835,6 +881,13 @@ let add_cuts l local dbnames =
(inAutoHint (local,dbname, AddCut l)))
dbnames
+let add_mode l m local dbnames =
+ List.iter
+ (fun dbname -> Lib.add_anonymous_leaf
+ (let m' = make_mode l m in
+ (inAutoHint (local,dbname, AddMode (l,m')))))
+ dbnames
+
let add_transparency l b local dbnames =
List.iter
(fun dbname -> Lib.add_anonymous_leaf
@@ -870,6 +923,7 @@ type hints_entry =
| HintsCutEntry of hints_path
| HintsUnfoldEntry of evaluable_global_reference list
| HintsTransparencyEntry of evaluable_global_reference list * bool
+ | HintsModeEntry of global_reference * bool list
| HintsExternEntry of
int * (patvar list * constr_pattern) option * glob_tactic_expr
@@ -914,22 +968,19 @@ let interp_hints poly =
let f c =
let evd,c = Constrintern.interp_open_constr (Global.env()) Evd.empty c in
prepare_hint true (Global.env()) Evd.empty (evd,c) in
- let fr r =
+ let fref r =
let gr = global_with_alias r in
- let r' = evaluable_of_global_reference (Global.env()) gr in
Dumpglob.add_glob (loc_of_reference r) gr;
- r' in
+ gr in
+ let fr r =
+ evaluable_of_global_reference (Global.env()) (fref r)
+ in
let fi c =
match c with
| HintsReference c ->
let gr = global_with_alias c in
(PathHints [gr], poly, IsGlobRef gr)
- | HintsConstr c ->
- (* if poly then *)
- (* errorlabstrm "Hint" (Ppconstr.pr_constr_expr c ++ spc () ++ *)
- (* str" is a term and cannot be made a polymorphic hint," ++ *)
- (* str" only global references can be polymorphic hints.") *)
- (* else *) (PathAny, poly, f c)
+ | HintsConstr c -> (PathAny, poly, f c)
in
let fres (pri, b, r) =
let path, poly, gr = fi r in
@@ -942,6 +993,7 @@ let interp_hints poly =
| HintsUnfold lhints -> HintsUnfoldEntry (List.map fr lhints)
| HintsTransparency (lhints, b) ->
HintsTransparencyEntry (List.map fr lhints, b)
+ | HintsMode (r, l) -> HintsModeEntry (fref r, l)
| HintsConstructors lqid ->
let constr_hints_of_ind qid =
let ind = global_inductive_with_alias qid in
@@ -968,6 +1020,7 @@ let add_hints local dbnames0 h =
| HintsResolveEntry lhints -> add_resolves env sigma lhints local dbnames
| HintsImmediateEntry lhints -> add_trivials env sigma lhints local dbnames
| HintsCutEntry lhints -> add_cuts lhints local dbnames
+ | HintsModeEntry (l,m) -> add_mode l m local dbnames
| HintsUnfoldEntry lhints -> add_unfolds lhints local dbnames
| HintsTransparencyEntry (lhints, b) ->
add_transparency lhints b local dbnames
@@ -1030,11 +1083,10 @@ let pr_hint_term cl =
let dbs = current_db () in
let valid_dbs =
let fn = try
- let hdc = head_constr_bound cl in
- let hd = head_of_constr_reference hdc in
+ let hdc = decompose_app_bound cl in
if occur_existential cl then
- Hint_db.map_all hd
- else Hint_db.map_auto (hd, cl)
+ Hint_db.map_existential hdc cl
+ else Hint_db.map_auto hdc cl
with Bound -> Hint_db.map_none
in
let fn db = List.map (fun x -> 0, x) (fn db) in
@@ -1063,11 +1115,16 @@ let pr_applicable_hint () =
(* displays the whole hint database db *)
let pr_hint_db db =
+ let pr_mode = prvect_with_sep spc (fun x -> if x then str"+" else str"-") in
+ let pr_modes l =
+ if List.is_empty l then mt ()
+ else str" (modes " ++ prlist_with_sep pr_comma pr_mode l ++ str")"
+ in
let content =
- let fold head hintlist accu =
+ let fold head modes hintlist accu =
let goal_descr = match head with
| None -> str "For any goal"
- | Some head -> str "For " ++ pr_global head
+ | Some head -> str "For " ++ pr_global head ++ pr_modes modes
in
let hints = pr_hint_list (List.map (fun x -> (0, x)) hintlist) in
let hint_descr = hov 0 (goal_descr ++ str " -> " ++ hints) in
@@ -1395,8 +1452,8 @@ let hintmap_of hdc concl =
match hdc with
| None -> Hint_db.map_none
| Some hdc ->
- if occur_existential concl then Hint_db.map_all hdc
- else Hint_db.map_auto (hdc,concl)
+ if occur_existential concl then Hint_db.map_existential hdc concl
+ else Hint_db.map_auto hdc concl
let exists_evaluable_reference env = function
| EvalConstRef _ -> true
@@ -1458,8 +1515,8 @@ and my_find_search_delta db_list local_db hdc concl =
match hdc with None -> Hint_db.map_none db
| Some hdc ->
if (Id.Pred.is_empty ids && Cpred.is_empty csts)
- then Hint_db.map_auto (hdc,concl) db
- else Hint_db.map_all hdc db
+ then Hint_db.map_auto hdc concl db
+ else Hint_db.map_existential hdc concl db
in auto_flags_of_state st, l
in List.map (fun x -> (Some flags,x)) l)
(local_db::db_list)
@@ -1489,8 +1546,8 @@ and tac_of_hint dbg db_list local_db concl (flags, ({pat=p; code=t;poly=poly}))
and trivial_resolve dbg mod_delta db_list local_db cl =
try
let head =
- try let hdconstr = head_constr_bound cl in
- Some (head_of_constr_reference hdconstr)
+ try let hdconstr = decompose_app_bound cl in
+ Some hdconstr
with Bound -> None
in
List.map (tac_of_hint dbg db_list local_db cl)
@@ -1543,8 +1600,8 @@ let h_trivial ?(debug=Off) lems l = gen_trivial ~debug lems l
let possible_resolve dbg mod_delta db_list local_db cl =
try
let head =
- try let hdconstr = head_constr_bound cl in
- Some (head_of_constr_reference hdconstr)
+ try let hdconstr = decompose_app_bound cl in
+ Some hdconstr
with Bound -> None
in
List.map (tac_of_hint dbg db_list local_db cl)