aboutsummaryrefslogtreecommitdiffhomepage
path: root/interp
diff options
context:
space:
mode:
authorGravatar Hugo Herbelin <Hugo.Herbelin@inria.fr>2017-08-17 20:12:55 +0200
committerGravatar Hugo Herbelin <Hugo.Herbelin@inria.fr>2018-02-20 10:03:06 +0100
commite4d93d1cef27d3a8c1e36139fc1e118730406f67 (patch)
tree0149d4c6ff1fc4cc978e796f303ee6dcdda65074 /interp
parent50970e4043d73d9a4fbd17ffe765745f6d726317 (diff)
Adding general support for irrefutable disjunctive patterns.
This now works not only for parsing of fun/forall (as in 8.6), but also for arbitraty notations with binders and for printing.
Diffstat (limited to 'interp')
-rw-r--r--interp/constrexpr_ops.ml4
-rw-r--r--interp/constrexpr_ops.mli2
-rw-r--r--interp/constrextern.ml34
-rw-r--r--interp/constrintern.ml86
-rw-r--r--interp/notation_ops.ml120
-rw-r--r--interp/notation_ops.mli6
6 files changed, 153 insertions, 99 deletions
diff --git a/interp/constrexpr_ops.ml b/interp/constrexpr_ops.ml
index 4877bf271..8aca6e333 100644
--- a/interp/constrexpr_ops.ml
+++ b/interp/constrexpr_ops.ml
@@ -547,6 +547,10 @@ let coerce_to_name = function
| { CAst.loc; _ } -> CErrors.user_err ?loc ~hdr:"coerce_to_name"
(str "This expression should be a name.")
+let mkCPatOr ?loc = function
+ | [pat] -> pat
+ | disjpat -> CAst.make ?loc @@ (CPatOr disjpat)
+
let mkAppPattern ?loc p lp =
let open CAst in
make ?loc @@ (match p.v with
diff --git a/interp/constrexpr_ops.mli b/interp/constrexpr_ops.mli
index 9a59d66f4..0b00b0e4d 100644
--- a/interp/constrexpr_ops.mli
+++ b/interp/constrexpr_ops.mli
@@ -54,6 +54,8 @@ val mkCLambdaN : ?loc:Loc.t -> local_binder_expr list -> constr_expr -> constr_e
val mkCProdN : ?loc:Loc.t -> local_binder_expr list -> constr_expr -> constr_expr
(** Same as [prod_constr_expr], with location *)
+val mkCPatOr : ?loc:Loc.t -> cases_pattern_expr list -> cases_pattern_expr
+
val mkAppPattern : ?loc:Loc.t -> cases_pattern_expr -> cases_pattern_expr list -> cases_pattern_expr
(** Apply a list of pattern arguments to a pattern *)
diff --git a/interp/constrextern.ml b/interp/constrextern.ml
index 67e19d125..9e18966b6 100644
--- a/interp/constrextern.ml
+++ b/interp/constrextern.ml
@@ -922,16 +922,21 @@ and extern_typ (_,scopes) =
and sub_extern inctx (_,scopes) = extern inctx (None,scopes)
and factorize_prod scopes vars na bk aty c =
+ let store, get = set_temporary_memory () in
match na, DAst.get c with
- | Name id, GCases (LetPatternStyle, None, [(e,(Anonymous,None))],[(_,(_,[p],b))])
- when is_gvar id e ->
- let p = if occur_glob_constr id b then set_pat_alias id p else p in
+ | Name id, GCases (LetPatternStyle, None, [(e,(Anonymous,None))],(_::_ as eqns))
+ when is_gvar id e && List.length (store (factorize_eqns eqns)) = 1 ->
+ (match get () with
+ | [(_,(ids,disj_of_patl,b))] ->
+ let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in
+ let disjpat = if occur_glob_constr id b then List.map (set_pat_alias id) disjpat else disjpat in
let b = extern_typ scopes vars b in
- let p = extern_cases_pattern_in_scope scopes vars p in
+ let p = mkCPatOr (List.map (extern_cases_pattern_in_scope scopes vars) disjpat) in
let binder = CLocalPattern (c.loc,(p,None)) in
(match b.v with
| CProdN (bl,b) -> CProdN (binder::bl,b)
| _ -> CProdN ([binder],b))
+ | _ -> assert false)
| _, _ ->
let c = extern_typ scopes vars c in
match na, c.v with
@@ -945,16 +950,21 @@ and factorize_prod scopes vars na bk aty c =
CProdN ([CLocalAssum([Loc.tag na],Default bk,aty)],c)
and factorize_lambda inctx scopes vars na bk aty c =
+ let store, get = set_temporary_memory () in
match na, DAst.get c with
- | Name id, GCases (LetPatternStyle, None, [(e,(Anonymous,None))],[(_,(_,[p],b))])
- when is_gvar id e ->
- let p = if occur_glob_constr id b then set_pat_alias id p else p in
+ | Name id, GCases (LetPatternStyle, None, [(e,(Anonymous,None))],(_::_ as eqns))
+ when is_gvar id e && List.length (store (factorize_eqns eqns)) = 1 ->
+ (match get () with
+ | [(_,(ids,disj_of_patl,b))] ->
+ let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in
+ let disjpat = if occur_glob_constr id b then List.map (set_pat_alias id) disjpat else disjpat in
let b = sub_extern inctx scopes vars b in
- let p = extern_cases_pattern_in_scope scopes vars p in
+ let p = mkCPatOr (List.map (extern_cases_pattern_in_scope scopes vars) disjpat) in
let binder = CLocalPattern (c.loc,(p,None)) in
(match b.v with
| CLambdaN (bl,b) -> CLambdaN (binder::bl,b)
| _ -> CLambdaN ([binder],b))
+ | _ -> assert false)
| _, _ ->
let c = sub_extern inctx scopes vars c in
match c.v with
@@ -994,7 +1004,7 @@ and extern_local_binder scopes vars = function
| GLocalPattern ((p,_),_,bk,ty) ->
let ty =
if !Flags.raw_print then Some (extern_typ scopes vars ty) else None in
- let p = extern_cases_pattern vars p in
+ let p = mkCPatOr (List.map (extern_cases_pattern vars) p) in
let (assums,ids,l) = extern_local_binder scopes vars l in
(assums,ids, CLocalPattern(Loc.tag @@ (p,ty)) :: l)
@@ -1066,10 +1076,10 @@ and extern_notation (tmp_scope,scopes as allscopes) vars t = function
termlists in
let bl =
List.map (fun (bl,(scopt,scl)) ->
- extern_cases_pattern_in_scope (scopt,scl@scopes') vars bl)
+ mkCPatOr (List.map (extern_cases_pattern_in_scope (scopt,scl@scopes') vars) bl))
binders in
- let bll =
- List.map (fun (bl,(scopt,scl)) ->
+ let bll =
+ List.map (fun (bl,(scopt,scl)) ->
pi3 (extern_local_binder (scopt,scl@scopes') vars bl))
binderlists in
insert_delimiters (make_notation loc ntn (l,ll,bl,bll)) key)
diff --git a/interp/constrintern.ml b/interp/constrintern.ml
index 63cf66bdd..379d09e89 100644
--- a/interp/constrintern.ml
+++ b/interp/constrintern.ml
@@ -441,20 +441,19 @@ let intern_letin_binder intern ntnvars env ((loc,na as locna),def,ty) =
(na,Explicit,term,ty))
let intern_cases_pattern_as_binder ?loc ntnvars env p =
- let il,cp =
- match !intern_cases_pattern_fwd ntnvars (None,env.scopes) p with
- | (il, [(subst,cp)]) ->
- if not (Id.Map.equal Id.equal subst Id.Map.empty) then
- user_err ?loc (str "Unsupported nested \"as\" clause.");
- il,cp
- | _ -> assert false
+ let il,disjpat =
+ let (il, subst_disjpat) = !intern_cases_pattern_fwd ntnvars (None,env.scopes) p in
+ let substl,disjpat = List.split subst_disjpat in
+ if not (List.for_all (fun subst -> Id.Map.equal Id.equal subst Id.Map.empty) substl) then
+ user_err ?loc (str "Unsupported nested \"as\" clause.");
+ il,disjpat
in
let env = List.fold_right (fun (loc,id) env -> push_name_env ntnvars (Variable,[],[],[]) env (loc,Name id)) il env in
- let na = alias_of_pat cp in
+ let na = alias_of_pat (List.hd disjpat) in
let ienv = Name.fold_right Id.Set.remove na env.ids in
- let id = Namegen.next_name_away_with_default "pat" (alias_of_pat cp) ienv in
+ let id = Namegen.next_name_away_with_default "pat" na ienv in
let na = (loc, Name id) in
- env,((cp,il),id),na
+ env,((disjpat,il),id),na
let intern_local_binder_aux ?(global_level=false) intern ntnvars (env,bl) = function
| CLocalAssum(nal,bk,ty) ->
@@ -470,11 +469,11 @@ let intern_local_binder_aux ?(global_level=false) intern ntnvars (env,bl) = func
| Some ty -> ty
| None -> CAst.make ?loc @@ CHole(None,Misctypes.IntroAnonymous,None)
in
- let env, ((cp,il),id),na = intern_cases_pattern_as_binder ?loc ntnvars env p in
+ let env, ((disjpat,il),id),na = intern_cases_pattern_as_binder ?loc ntnvars env p in
let bk = Default Explicit in
let _, bl' = intern_assumption intern ntnvars env [na] bk tyc in
let _,(_,bk,t) = List.hd bl' in
- (env, (DAst.make ?loc @@ GLocalPattern((cp,List.map snd il),id,bk,t)) :: bl)
+ (env, (DAst.make ?loc @@ GLocalPattern((disjpat,List.map snd il),id,bk,t)) :: bl)
let intern_generalization intern env ntnvars loc bk ak c =
let c = intern {env with unb = true} c in
@@ -518,9 +517,11 @@ let rec expand_binders ?loc mk bl c =
expand_binders ?loc mk bl (DAst.make ?loc @@ GLetIn (n, b, oty, c))
| GLocalAssum (n, bk, t) ->
expand_binders ?loc mk bl (mk ?loc (n,bk,t) c)
- | GLocalPattern ((pat,ids), id, bk, ty) ->
+ | GLocalPattern ((disjpat,ids), id, bk, ty) ->
let tm = DAst.make ?loc (GVar id) in
- let c = DAst.make ?loc @@ GCases (Misctypes.LetPatternStyle, None, [tm,(Anonymous,None)], [loc,(ids,[pat], c)]) in
+ (* Distribute the disjunctive patterns over the shared right-hand side *)
+ let eqnl = List.map (fun pat -> (loc,(ids,[pat],c))) disjpat in
+ let c = DAst.make ?loc @@ GCases (Misctypes.LetPatternStyle, None, [tm,(Anonymous,None)], eqnl) in
expand_binders ?loc mk bl (mk ?loc (Name id,Explicit,ty) c)
(**********************************************************************)
@@ -543,26 +544,32 @@ let find_fresh_name renaming (terms,termlists,binders,binderlists) avoid id =
(* TODO binders *)
next_ident_away_from id (fun id -> Id.Set.mem id fvs3)
+let is_var store pat =
+ match DAst.get pat with
+ | PatVar na -> store na; true
+ | _ -> false
+
let traverse_binder intern_pat ntnvars (terms,_,binders,_ as subst) avoid (renaming,env) = function
| Anonymous -> (renaming,env), None, Anonymous
| Name id ->
+ let store,get = set_temporary_memory () in
try
(* We instantiate binder name with patterns which may be parsed as terms *)
let pat = coerce_to_cases_pattern_expr (fst (Id.Map.find id terms)) in
- let env,((pat,ids),id),na = intern_pat ntnvars env pat in
- let pat, na = match DAst.get pat with
- | PatVar na -> None, na
- | _ -> Some ((List.map snd ids,pat),id), snd na in
+ let env,((disjpat,ids),id),na = intern_pat ntnvars env pat in
+ let pat, na = match disjpat with
+ | [pat] when is_var store pat -> let na = get () in None, na
+ | _ -> Some ((List.map snd ids,disjpat),id), snd na in
(renaming,env), pat, na
with Not_found ->
try
(* Trying to associate a pattern *)
let pat,scopes = Id.Map.find id binders in
let env = set_env_scopes env scopes in
- let env,((pat,ids),id),na = intern_pat ntnvars env pat in
- let pat, na = match DAst.get pat with
- | PatVar na -> None, na
- | _ -> Some ((List.map snd ids,pat),id), snd na in
+ let env,((disjpat,ids),id),na = intern_pat ntnvars env pat in
+ let pat, na = match disjpat with
+ | [pat] when is_var store pat -> let na = get () in None, na
+ | _ -> Some ((List.map snd ids,disjpat),id), snd na in
(renaming,env), pat, na
with Not_found ->
(* Binders not bound in the notation do not capture variables *)
@@ -582,10 +589,16 @@ type binder_action =
let dmap_with_loc f n =
CAst.map_with_loc (fun ?loc c -> f ?loc (DAst.get_thunk c)) n
+let error_cannot_coerce_wildcard_term ?loc () =
+ user_err ?loc Pp.(str "Cannot turn \"_\" into a term.")
+
+let error_cannot_coerce_disjunctive_pattern_term ?loc () =
+ user_err ?loc Pp.(str "Cannot turn a disjunctive pattern into a term.")
+
let terms_of_binders bl =
let rec term_of_pat pt = dmap_with_loc (fun ?loc -> function
| PatVar (Name id) -> CRef (Ident (loc,id), None)
- | PatVar (Anonymous) -> user_err Pp.(str "Cannot turn \"_\" into a term.")
+ | PatVar (Anonymous) -> error_cannot_coerce_wildcard_term ?loc ()
| PatCstr (c,l,_) ->
let r = Qualid (loc,qualid_of_path (path_of_global (ConstructRef c))) in
let hole = CAst.make ?loc @@ CHole (None,Misctypes.IntroAnonymous,None) in
@@ -599,7 +612,8 @@ let terms_of_binders bl =
| GLocalDef (Name id,_,_,_) -> extract_variables l
| GLocalDef (Anonymous,_,_,_)
| GLocalAssum (Anonymous,_,_) -> user_err Pp.(str "Cannot turn \"_\" into a term.")
- | GLocalPattern ((u,_),_,_,_) -> term_of_pat u :: extract_variables l
+ | GLocalPattern (([u],_),_,_,_) -> term_of_pat u :: extract_variables l
+ | GLocalPattern ((_,_),_,_,_) -> error_cannot_coerce_disjunctive_pattern_term ?loc ()
end
| [] -> [] in
extract_variables bl
@@ -676,8 +690,10 @@ let instantiate_notation_constr loc intern intern_pat ntnvars subst infos c =
in
let mk_env' (c, (tmp_scope, subscopes)) =
let nenv = {env with tmp_scope; scopes = subscopes @ env.scopes} in
- let _,((pat,_),_),_ = intern_pat ntnvars nenv c in
- (glob_constr_of_cases_pattern pat, None)
+ let _,((disjpat,_),_),_ = intern_pat ntnvars nenv c in
+ match disjpat with
+ | [pat] -> (glob_constr_of_cases_pattern pat, None)
+ | _ -> error_cannot_coerce_disjunctive_pattern_term ?loc:c.CAst.loc ()
in
let terms = Id.Map.map mk_env terms in
let binders = Id.Map.map mk_env' binders in
@@ -708,14 +724,14 @@ let instantiate_notation_constr loc intern intern_pat ntnvars subst infos c =
(* Two special cases to keep binder name synchronous with BinderType *)
| NProd (na,NHole(Evar_kinds.BinderType na',naming,arg),c')
when Name.equal na na' ->
- let subinfos,pat,na = traverse_binder intern_pat ntnvars subst avoid subinfos na in
+ let subinfos,disjpat,na = traverse_binder intern_pat ntnvars subst avoid subinfos na in
let ty = DAst.make ?loc @@ GHole (Evar_kinds.BinderType na,naming,arg) in
- DAst.make ?loc @@ GProd (na,Explicit,ty,Option.fold_right apply_cases_pattern pat (aux subst' subinfos c'))
+ DAst.make ?loc @@ GProd (na,Explicit,ty,Option.fold_right apply_cases_pattern disjpat (aux subst' subinfos c'))
| NLambda (na,NHole(Evar_kinds.BinderType na',naming,arg),c')
when Name.equal na na' ->
- let subinfos,pat,na = traverse_binder intern_pat ntnvars subst avoid subinfos na in
+ let subinfos,disjpat,na = traverse_binder intern_pat ntnvars subst avoid subinfos na in
let ty = DAst.make ?loc @@ GHole (Evar_kinds.BinderType na,naming,arg) in
- DAst.make ?loc @@ GLambda (na,Explicit,ty,Option.fold_right apply_cases_pattern pat (aux subst' subinfos c'))
+ DAst.make ?loc @@ GLambda (na,Explicit,ty,Option.fold_right apply_cases_pattern disjpat (aux subst' subinfos c'))
| t ->
glob_constr_of_notation_constr_with_binders ?loc
(traverse_binder intern_pat ntnvars subst avoid) (aux subst') subinfos t
@@ -730,11 +746,13 @@ let instantiate_notation_constr loc intern intern_pat ntnvars subst infos c =
try
let pat,scopes = Id.Map.find id binders in
let env = set_env_scopes env scopes in
- (* We deactivate the check on hidden parameters *)
- (* since we are only interested in the pattern as a term *)
+ (* We deactivate impls to avoid the check on hidden parameters *)
+ (* and since we are only interested in the pattern as a term *)
let env = reset_hidden_inductive_implicit_test env in
- let env,((pat,ids),id),na = intern_pat ntnvars env pat in
- glob_constr_of_cases_pattern pat
+ let env,((disjpat,ids),id),na = intern_pat ntnvars env pat in
+ match disjpat with
+ | [pat] -> glob_constr_of_cases_pattern pat
+ | _ -> user_err Pp.(str "Cannot turn a disjunctive pattern into a term.")
with Not_found ->
try
match binderopt with
diff --git a/interp/notation_ops.ml b/interp/notation_ops.ml
index c44863791..81cdecf03 100644
--- a/interp/notation_ops.ml
+++ b/interp/notation_ops.ml
@@ -101,17 +101,24 @@ let name_to_ident = function
let to_id g e id = let e,na = g e (Name id) in e,name_to_ident na
+let product_of_cases_patterns patl =
+ List.fold_right (fun patl restl ->
+ List.flatten (List.map (fun p -> List.map (fun rest -> p::rest) restl) patl))
+ patl [[]]
+
let rec cases_pattern_fold_map ?loc g e = DAst.with_val (function
| PatVar na ->
- let e',pat,na' = g e na in
- e', (match pat with
- | None -> DAst.make ?loc @@ PatVar na'
- | Some ((_,pat),_) -> pat)
+ let e',disjpat,na' = g e na in
+ e', (match disjpat with
+ | None -> [DAst.make ?loc @@ PatVar na']
+ | Some ((_,disjpat),_) -> disjpat)
| PatCstr (cstr,patl,na) ->
- let e',pat,na' = g e na in
- if pat <> None then user_err (Pp.str "Unable to instantiate an \"as\" clause with a pattern.");
+ let e',disjpat,na' = g e na in
+ if disjpat <> None then user_err (Pp.str "Unable to instantiate an \"as\" clause with a pattern.");
let e',patl' = List.fold_left_map (cases_pattern_fold_map ?loc g) e patl in
- e', DAst.make ?loc @@ PatCstr (cstr,patl',na')
+ (* Distribute outwards the inner disjunctive patterns *)
+ let disjpatl' = product_of_cases_patterns patl' in
+ e', List.map (fun patl' -> DAst.make ?loc @@ PatCstr (cstr,patl',na')) disjpatl'
)
let subst_binder_type_vars l = function
@@ -141,14 +148,14 @@ let rec subst_glob_vars l gc = DAst.map (function
let ldots_var = Id.of_string ".."
let protect g e na =
- let e',pat,na = g e na in
- if pat <> None then user_err (Pp.str "Unsupported substitution of an arbitrary pattern.");
+ let e',disjpat,na = g e na in
+ if disjpat <> None then user_err (Pp.str "Unsupported substitution of an arbitrary pattern.");
e',na
-let apply_cases_pattern ?loc ((ids,pat),id) c =
+let apply_cases_pattern ?loc ((ids,disjpat),id) c =
let tm = DAst.make ?loc (GVar id) in
- DAst.make ?loc @@
- GCases (LetPatternStyle, None, [tm,(Anonymous,None)], [loc,(ids,[pat], c)])
+ let eqns = List.map (fun pat -> (loc,(ids,[pat],c))) disjpat in
+ DAst.make ?loc @@ GCases (LetPatternStyle, None, [tm,(Anonymous,None)], eqns)
let glob_constr_of_notation_constr_with_binders ?loc g f e nc =
let lt x = DAst.make ?loc x in lt @@ match nc with
@@ -167,14 +174,14 @@ let glob_constr_of_notation_constr_with_binders ?loc g f e nc =
let outerl = (ldots_var,inner)::(if swap then [x, lt @@ GVar y] else []) in
DAst.get (subst_glob_vars outerl it)
| NLambda (na,ty,c) ->
- let e',pat,na = g e na in GLambda (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) pat (f e' c))
+ let e',disjpat,na = g e na in GLambda (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) disjpat (f e' c))
| NProd (na,ty,c) ->
- let e',pat,na = g e na in GProd (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) pat (f e' c))
+ let e',disjpat,na = g e na in GProd (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) disjpat (f e' c))
| NLetIn (na,b,t,c) ->
- let e',pat,na = g e na in
- (match pat with
+ let e',disjpat,na = g e na in
+ (match disjpat with
| None -> GLetIn (na,f e b,Option.map (f e) t,f e' c)
- | Some pat -> DAst.get (apply_cases_pattern ?loc pat (f e' c)))
+ | Some disjpat -> DAst.get (apply_cases_pattern ?loc disjpat (f e' c)))
| NCases (sty,rtntypopt,tml,eqnl) ->
let e',tml' = List.fold_right (fun (tm,(na,t)) (e',tml') ->
let e',t' = match t with
@@ -183,15 +190,16 @@ let glob_constr_of_notation_constr_with_binders ?loc g f e nc =
let e',nal' = List.fold_right (fun na (e',nal) ->
let e',na' = protect g e' na in
e',na'::nal) nal (e',[]) in
- e',Some (Loc.tag ?loc (ind,nal')) in
+ e',Some (Loc.tag ?loc (ind,nal')) in
let e',na' = protect g e' na in
- (e',(f e tm,(na',t'))::tml')) tml (e,[]) in
- let fold (idl,e) na = let (e,pat,na) = g e na in ((Name.cons na idl,e),pat,na) in
+ (e',(f e tm,(na',t'))::tml')) tml (e,[]) in
+ let fold (idl,e) na = let (e,disjpat,na) = g e na in ((Name.cons na idl,e),disjpat,na) in
let eqnl' = List.map (fun (patl,rhs) ->
- let ((idl,e),patl) =
- List.fold_left_map (cases_pattern_fold_map ?loc fold) ([],e) patl in
- Loc.tag (idl,patl,f e rhs)) eqnl in
- GCases (sty,Option.map (f e') rtntypopt,tml',eqnl')
+ let ((idl,e),patl) =
+ List.fold_left_map (cases_pattern_fold_map ?loc fold) ([],e) patl in
+ let disjpatl = product_of_cases_patterns patl in
+ List.map (fun patl -> Loc.tag (idl,patl,f e rhs)) disjpatl) eqnl in
+ GCases (sty,Option.map (f e') rtntypopt,tml',List.flatten eqnl')
| NLetTuple (nal,(na,po),b,c) ->
let e',nal = List.fold_left_map (protect g) e nal in
let e'',na = protect g e na in
@@ -806,8 +814,8 @@ let unify_binder_upto alp b b' =
| GLocalDef (na,bk,c,t), GLocalDef (na',bk',c',t') ->
let alp, na = unify_name_upto alp na na' in
alp, DAst.make ?loc @@ GLocalDef (na, unify_binding_kind bk bk', unify_term alp c c', unify_opt_term alp t t')
- | GLocalPattern ((p,ids),id,bk,t), GLocalPattern ((p',_),_,bk',t') ->
- let alp, p = unify_pat_upto alp p p' in
+ | GLocalPattern ((disjpat,ids),id,bk,t), GLocalPattern ((disjpat',_),_,bk',t') when List.length disjpat = List.length disjpat' ->
+ let alp, p = List.fold_left2_map unify_pat_upto alp disjpat disjpat' in
alp, DAst.make ?loc @@ GLocalPattern ((p,ids), id, unify_binding_kind bk bk', unify_term alp t t')
| _ -> raise No_match
@@ -840,9 +848,9 @@ let unify_term_binder alp c = DAst.(map (fun b' ->
match DAst.get c, b' with
| GVar id, GLocalAssum (na', bk', t') ->
GLocalAssum (unify_id alp id na', bk', t')
- | _, GLocalPattern ((p',ids), id, bk', t') ->
+ | _, GLocalPattern (([p'],ids), id, bk', t') ->
let p = pat_binder_of_term c in
- GLocalPattern ((unify_pat alp p p',ids), id, bk', t')
+ GLocalPattern (([unify_pat alp p p'],ids), id, bk', t')
| _ -> raise No_match))
let rec unify_terms_binders alp cl bl' =
@@ -895,23 +903,23 @@ let bind_binding_as_term_env alp (terms,termlists,binders,binderlists as sigma)
let pat = try force_cases_pattern (cases_pattern_of_glob_constr Anonymous c) with Not_found -> raise No_match in
try
(* If already bound to a binder, unify the term and the binder *)
- let pat' = Id.List.assoc var binders in
- let pat'' = unify_pat alp pat pat' in
- if pat' == pat'' then sigma
+ let patl' = Id.List.assoc var binders in
+ let patl'' = List.map2 (unify_pat alp) [pat] patl' in
+ if patl' == patl'' then sigma
else
let sigma = (terms,termlists,Id.List.remove_assoc var binders,binderlists) in
- add_binding_env alp sigma var pat''
- with Not_found -> add_binding_env alp sigma var pat
+ add_binding_env alp sigma var patl''
+ with Not_found -> add_binding_env alp sigma var [pat]
-let bind_binding_env alp (terms,termlists,binders,binderlists as sigma) var pat =
+let bind_binding_env alp (terms,termlists,binders,binderlists as sigma) var disjpat =
try
(* If already bound to a binder possibly *)
(* generating an alpha-renaming from unifying the new binder *)
- let pat' = Id.List.assoc var binders in
- let alp, pat = unify_pat_upto alp pat pat' in
+ let disjpat' = Id.List.assoc var binders in
+ let alp, disjpat = List.fold_left2_map unify_pat_upto alp disjpat disjpat' in
let sigma = (terms,termlists,Id.List.remove_assoc var binders,binderlists) in
- alp, add_binding_env alp sigma var pat
- with Not_found -> alp, add_binding_env alp sigma var pat
+ alp, add_binding_env alp sigma var disjpat
+ with Not_found -> alp, add_binding_env alp sigma var disjpat
let bind_bindinglist_env alp (terms,termlists,binders,binderlists as sigma) var bl =
let bl = List.rev bl in
@@ -955,7 +963,7 @@ let match_opt f sigma t1 t2 = match (t1,t2) with
let match_names metas (alp,sigma) na1 na2 = match (na1,na2) with
| (na1,Name id2) when is_onlybinding_meta id2 metas ->
- bind_binding_env alp sigma id2 (DAst.make (PatVar na1))
+ bind_binding_env alp sigma id2 [DAst.make (PatVar na1)]
| (Name id1,Name id2) when is_term_meta id2 metas ->
(* We let the non-binding occurrence define the rhs and hence reason up to *)
(* alpha-conversion for the given occurrence of the name (see #4592)) *)
@@ -970,7 +978,7 @@ let match_names metas (alp,sigma) na1 na2 = match (na1,na2) with
let rec match_cases_pattern_binders metas (alp,sigma as acc) pat1 pat2 =
match DAst.get pat1, DAst.get pat2 with
| _, PatVar (Name id2) when is_onlybinding_pattern_like_meta id2 metas ->
- bind_binding_env alp sigma id2 pat1
+ bind_binding_env alp sigma id2 [pat1]
| PatVar na1, PatVar na2 -> match_names metas acc na1 na2
| PatCstr (c1,patl1,na1), PatCstr (c2,patl2,na2)
when eq_constructor c1 c2 && Int.equal (List.length patl1) (List.length patl2) ->
@@ -1205,18 +1213,27 @@ and match_binders u alp metas na1 na2 sigma b1 b2 =
and match_extended_binders ?loc isprod u alp metas na1 na2 bk t sigma b1 b2 =
(* Match binders which can be substituted by a pattern *)
+ let store, get = set_temporary_memory () in
match na1, DAst.get b1, na2 with
(* Matching individual binders as part of a recursive pattern *)
- | Name p, GCases (LetPatternStyle,None,[(e,_)],[(_,(ids,[cp],b1))]), Name id
- when is_gvar p e && is_bindinglist_meta id metas ->
- let cp = if occur_glob_constr p b1 then set_pat_alias p cp else cp in
- let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalPattern ((cp,ids),p,bk,t)] in
+ | Name p, GCases (LetPatternStyle,None,[(e,_)],(_::_ as eqns)), Name id
+ when is_gvar p e && is_bindinglist_meta id metas && List.length (store (Detyping.factorize_eqns eqns)) = 1 ->
+ (match get () with
+ | [(_,(ids,disj_of_patl,b1))] ->
+ let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in
+ let disjpat = if occur_glob_constr p b1 then List.map (set_pat_alias p) disjpat else disjpat in
+ let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalPattern ((disjpat,ids),p,bk,t)] in
match_in u alp metas sigma b1 b2
- | Name p, GCases (LetPatternStyle,None,[(e,_)],[(_,(_,[cp],b1))]), Name id
- when is_gvar p e && is_onlybinding_pattern_like_meta id metas ->
- let cp = if occur_glob_constr p b1 then set_pat_alias p cp else cp in
- let alp,sigma = bind_binding_env alp sigma id cp in
+ | _ -> assert false)
+ | Name p, GCases (LetPatternStyle,None,[(e,_)],(_::_ as eqns)), Name id
+ when is_gvar p e && is_onlybinding_pattern_like_meta id metas && List.length (store (Detyping.factorize_eqns eqns)) = 1 ->
+ (match get () with
+ | [(_,(ids,disj_of_patl,b1))] ->
+ let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in
+ let disjpat = if occur_glob_constr p b1 then List.map (set_pat_alias p) disjpat else disjpat in
+ let alp,sigma = bind_binding_env alp sigma id disjpat in
match_in u alp metas sigma b1 b2
+ | _ -> assert false)
| _, _, Name id when is_bindinglist_meta id metas && (not isprod || na1 != Anonymous)->
let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalAssum (na1,bk,t)] in
match_in u alp metas sigma b1 b2
@@ -1243,8 +1260,11 @@ let match_notation_constr u c (metas,pat) =
let term = try Id.List.assoc x terms with Not_found -> raise No_match in
((term, scl)::terms',termlists',binders',binderlists')
| NtnTypeBinder NtnParsedAsConstr ->
- let v = glob_constr_of_cases_pattern (Id.List.assoc x binders) in
- ((v,scl)::terms',termlists',binders',binderlists')
+ (match Id.List.assoc x binders with
+ | [pat] ->
+ let v = glob_constr_of_cases_pattern pat in
+ ((v,scl)::terms',termlists',binders',binderlists')
+ | _ -> raise No_match)
| NtnTypeBinder (NtnParsedAsIdent | NtnParsedAsPattern) ->
(terms',termlists',(Id.List.assoc x binders,scl)::binders',binderlists')
| NtnTypeConstrList ->
diff --git a/interp/notation_ops.mli b/interp/notation_ops.mli
index 1a2dfc9ca..746f52e48 100644
--- a/interp/notation_ops.mli
+++ b/interp/notation_ops.mli
@@ -34,10 +34,10 @@ val notation_constr_of_glob_constr : notation_interp_env ->
(** Re-interpret a notation as a [glob_constr], taking care of binders *)
val apply_cases_pattern : ?loc:Loc.t ->
- (Id.t list * cases_pattern) * Id.t -> glob_constr -> glob_constr
+ (Id.t list * cases_pattern_disjunction) * Id.t -> glob_constr -> glob_constr
val glob_constr_of_notation_constr_with_binders : ?loc:Loc.t ->
- ('a -> Name.t -> 'a * ((Id.t list * cases_pattern) * Id.t) option * Name.t) ->
+ ('a -> Name.t -> 'a * ((Id.t list * cases_pattern_disjunction) * Id.t) option * Name.t) ->
('a -> notation_constr -> glob_constr) ->
'a -> notation_constr -> glob_constr
@@ -52,7 +52,7 @@ exception No_match
val match_notation_constr : bool -> 'a glob_constr_g -> interpretation ->
('a glob_constr_g * subscopes) list * ('a glob_constr_g list * subscopes) list *
- ('a cases_pattern_g * subscopes) list *
+ ('a cases_pattern_disjunction_g * subscopes) list *
('a extended_glob_local_binder_g list * subscopes) list
val match_notation_constr_cases_pattern :