From e4d93d1cef27d3a8c1e36139fc1e118730406f67 Mon Sep 17 00:00:00 2001 From: Hugo Herbelin Date: Thu, 17 Aug 2017 20:12:55 +0200 Subject: Adding general support for irrefutable disjunctive patterns. This now works not only for parsing of fun/forall (as in 8.6), but also for arbitraty notations with binders and for printing. --- interp/constrexpr_ops.ml | 4 ++ interp/constrexpr_ops.mli | 2 + interp/constrextern.ml | 34 ++++++++----- interp/constrintern.ml | 86 ++++++++++++++++++++------------- interp/notation_ops.ml | 120 +++++++++++++++++++++++++++------------------- interp/notation_ops.mli | 6 +-- 6 files changed, 153 insertions(+), 99 deletions(-) (limited to 'interp') diff --git a/interp/constrexpr_ops.ml b/interp/constrexpr_ops.ml index 4877bf271..8aca6e333 100644 --- a/interp/constrexpr_ops.ml +++ b/interp/constrexpr_ops.ml @@ -547,6 +547,10 @@ let coerce_to_name = function | { CAst.loc; _ } -> CErrors.user_err ?loc ~hdr:"coerce_to_name" (str "This expression should be a name.") +let mkCPatOr ?loc = function + | [pat] -> pat + | disjpat -> CAst.make ?loc @@ (CPatOr disjpat) + let mkAppPattern ?loc p lp = let open CAst in make ?loc @@ (match p.v with diff --git a/interp/constrexpr_ops.mli b/interp/constrexpr_ops.mli index 9a59d66f4..0b00b0e4d 100644 --- a/interp/constrexpr_ops.mli +++ b/interp/constrexpr_ops.mli @@ -54,6 +54,8 @@ val mkCLambdaN : ?loc:Loc.t -> local_binder_expr list -> constr_expr -> constr_e val mkCProdN : ?loc:Loc.t -> local_binder_expr list -> constr_expr -> constr_expr (** Same as [prod_constr_expr], with location *) +val mkCPatOr : ?loc:Loc.t -> cases_pattern_expr list -> cases_pattern_expr + val mkAppPattern : ?loc:Loc.t -> cases_pattern_expr -> cases_pattern_expr list -> cases_pattern_expr (** Apply a list of pattern arguments to a pattern *) diff --git a/interp/constrextern.ml b/interp/constrextern.ml index 67e19d125..9e18966b6 100644 --- a/interp/constrextern.ml +++ b/interp/constrextern.ml @@ -922,16 +922,21 @@ and extern_typ (_,scopes) = and sub_extern inctx (_,scopes) = extern inctx (None,scopes) and factorize_prod scopes vars na bk aty c = + let store, get = set_temporary_memory () in match na, DAst.get c with - | Name id, GCases (LetPatternStyle, None, [(e,(Anonymous,None))],[(_,(_,[p],b))]) - when is_gvar id e -> - let p = if occur_glob_constr id b then set_pat_alias id p else p in + | Name id, GCases (LetPatternStyle, None, [(e,(Anonymous,None))],(_::_ as eqns)) + when is_gvar id e && List.length (store (factorize_eqns eqns)) = 1 -> + (match get () with + | [(_,(ids,disj_of_patl,b))] -> + let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in + let disjpat = if occur_glob_constr id b then List.map (set_pat_alias id) disjpat else disjpat in let b = extern_typ scopes vars b in - let p = extern_cases_pattern_in_scope scopes vars p in + let p = mkCPatOr (List.map (extern_cases_pattern_in_scope scopes vars) disjpat) in let binder = CLocalPattern (c.loc,(p,None)) in (match b.v with | CProdN (bl,b) -> CProdN (binder::bl,b) | _ -> CProdN ([binder],b)) + | _ -> assert false) | _, _ -> let c = extern_typ scopes vars c in match na, c.v with @@ -945,16 +950,21 @@ and factorize_prod scopes vars na bk aty c = CProdN ([CLocalAssum([Loc.tag na],Default bk,aty)],c) and factorize_lambda inctx scopes vars na bk aty c = + let store, get = set_temporary_memory () in match na, DAst.get c with - | Name id, GCases (LetPatternStyle, None, [(e,(Anonymous,None))],[(_,(_,[p],b))]) - when is_gvar id e -> - let p = if occur_glob_constr id b then set_pat_alias id p else p in + | Name id, GCases (LetPatternStyle, None, [(e,(Anonymous,None))],(_::_ as eqns)) + when is_gvar id e && List.length (store (factorize_eqns eqns)) = 1 -> + (match get () with + | [(_,(ids,disj_of_patl,b))] -> + let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in + let disjpat = if occur_glob_constr id b then List.map (set_pat_alias id) disjpat else disjpat in let b = sub_extern inctx scopes vars b in - let p = extern_cases_pattern_in_scope scopes vars p in + let p = mkCPatOr (List.map (extern_cases_pattern_in_scope scopes vars) disjpat) in let binder = CLocalPattern (c.loc,(p,None)) in (match b.v with | CLambdaN (bl,b) -> CLambdaN (binder::bl,b) | _ -> CLambdaN ([binder],b)) + | _ -> assert false) | _, _ -> let c = sub_extern inctx scopes vars c in match c.v with @@ -994,7 +1004,7 @@ and extern_local_binder scopes vars = function | GLocalPattern ((p,_),_,bk,ty) -> let ty = if !Flags.raw_print then Some (extern_typ scopes vars ty) else None in - let p = extern_cases_pattern vars p in + let p = mkCPatOr (List.map (extern_cases_pattern vars) p) in let (assums,ids,l) = extern_local_binder scopes vars l in (assums,ids, CLocalPattern(Loc.tag @@ (p,ty)) :: l) @@ -1066,10 +1076,10 @@ and extern_notation (tmp_scope,scopes as allscopes) vars t = function termlists in let bl = List.map (fun (bl,(scopt,scl)) -> - extern_cases_pattern_in_scope (scopt,scl@scopes') vars bl) + mkCPatOr (List.map (extern_cases_pattern_in_scope (scopt,scl@scopes') vars) bl)) binders in - let bll = - List.map (fun (bl,(scopt,scl)) -> + let bll = + List.map (fun (bl,(scopt,scl)) -> pi3 (extern_local_binder (scopt,scl@scopes') vars bl)) binderlists in insert_delimiters (make_notation loc ntn (l,ll,bl,bll)) key) diff --git a/interp/constrintern.ml b/interp/constrintern.ml index 63cf66bdd..379d09e89 100644 --- a/interp/constrintern.ml +++ b/interp/constrintern.ml @@ -441,20 +441,19 @@ let intern_letin_binder intern ntnvars env ((loc,na as locna),def,ty) = (na,Explicit,term,ty)) let intern_cases_pattern_as_binder ?loc ntnvars env p = - let il,cp = - match !intern_cases_pattern_fwd ntnvars (None,env.scopes) p with - | (il, [(subst,cp)]) -> - if not (Id.Map.equal Id.equal subst Id.Map.empty) then - user_err ?loc (str "Unsupported nested \"as\" clause."); - il,cp - | _ -> assert false + let il,disjpat = + let (il, subst_disjpat) = !intern_cases_pattern_fwd ntnvars (None,env.scopes) p in + let substl,disjpat = List.split subst_disjpat in + if not (List.for_all (fun subst -> Id.Map.equal Id.equal subst Id.Map.empty) substl) then + user_err ?loc (str "Unsupported nested \"as\" clause."); + il,disjpat in let env = List.fold_right (fun (loc,id) env -> push_name_env ntnvars (Variable,[],[],[]) env (loc,Name id)) il env in - let na = alias_of_pat cp in + let na = alias_of_pat (List.hd disjpat) in let ienv = Name.fold_right Id.Set.remove na env.ids in - let id = Namegen.next_name_away_with_default "pat" (alias_of_pat cp) ienv in + let id = Namegen.next_name_away_with_default "pat" na ienv in let na = (loc, Name id) in - env,((cp,il),id),na + env,((disjpat,il),id),na let intern_local_binder_aux ?(global_level=false) intern ntnvars (env,bl) = function | CLocalAssum(nal,bk,ty) -> @@ -470,11 +469,11 @@ let intern_local_binder_aux ?(global_level=false) intern ntnvars (env,bl) = func | Some ty -> ty | None -> CAst.make ?loc @@ CHole(None,Misctypes.IntroAnonymous,None) in - let env, ((cp,il),id),na = intern_cases_pattern_as_binder ?loc ntnvars env p in + let env, ((disjpat,il),id),na = intern_cases_pattern_as_binder ?loc ntnvars env p in let bk = Default Explicit in let _, bl' = intern_assumption intern ntnvars env [na] bk tyc in let _,(_,bk,t) = List.hd bl' in - (env, (DAst.make ?loc @@ GLocalPattern((cp,List.map snd il),id,bk,t)) :: bl) + (env, (DAst.make ?loc @@ GLocalPattern((disjpat,List.map snd il),id,bk,t)) :: bl) let intern_generalization intern env ntnvars loc bk ak c = let c = intern {env with unb = true} c in @@ -518,9 +517,11 @@ let rec expand_binders ?loc mk bl c = expand_binders ?loc mk bl (DAst.make ?loc @@ GLetIn (n, b, oty, c)) | GLocalAssum (n, bk, t) -> expand_binders ?loc mk bl (mk ?loc (n,bk,t) c) - | GLocalPattern ((pat,ids), id, bk, ty) -> + | GLocalPattern ((disjpat,ids), id, bk, ty) -> let tm = DAst.make ?loc (GVar id) in - let c = DAst.make ?loc @@ GCases (Misctypes.LetPatternStyle, None, [tm,(Anonymous,None)], [loc,(ids,[pat], c)]) in + (* Distribute the disjunctive patterns over the shared right-hand side *) + let eqnl = List.map (fun pat -> (loc,(ids,[pat],c))) disjpat in + let c = DAst.make ?loc @@ GCases (Misctypes.LetPatternStyle, None, [tm,(Anonymous,None)], eqnl) in expand_binders ?loc mk bl (mk ?loc (Name id,Explicit,ty) c) (**********************************************************************) @@ -543,26 +544,32 @@ let find_fresh_name renaming (terms,termlists,binders,binderlists) avoid id = (* TODO binders *) next_ident_away_from id (fun id -> Id.Set.mem id fvs3) +let is_var store pat = + match DAst.get pat with + | PatVar na -> store na; true + | _ -> false + let traverse_binder intern_pat ntnvars (terms,_,binders,_ as subst) avoid (renaming,env) = function | Anonymous -> (renaming,env), None, Anonymous | Name id -> + let store,get = set_temporary_memory () in try (* We instantiate binder name with patterns which may be parsed as terms *) let pat = coerce_to_cases_pattern_expr (fst (Id.Map.find id terms)) in - let env,((pat,ids),id),na = intern_pat ntnvars env pat in - let pat, na = match DAst.get pat with - | PatVar na -> None, na - | _ -> Some ((List.map snd ids,pat),id), snd na in + let env,((disjpat,ids),id),na = intern_pat ntnvars env pat in + let pat, na = match disjpat with + | [pat] when is_var store pat -> let na = get () in None, na + | _ -> Some ((List.map snd ids,disjpat),id), snd na in (renaming,env), pat, na with Not_found -> try (* Trying to associate a pattern *) let pat,scopes = Id.Map.find id binders in let env = set_env_scopes env scopes in - let env,((pat,ids),id),na = intern_pat ntnvars env pat in - let pat, na = match DAst.get pat with - | PatVar na -> None, na - | _ -> Some ((List.map snd ids,pat),id), snd na in + let env,((disjpat,ids),id),na = intern_pat ntnvars env pat in + let pat, na = match disjpat with + | [pat] when is_var store pat -> let na = get () in None, na + | _ -> Some ((List.map snd ids,disjpat),id), snd na in (renaming,env), pat, na with Not_found -> (* Binders not bound in the notation do not capture variables *) @@ -582,10 +589,16 @@ type binder_action = let dmap_with_loc f n = CAst.map_with_loc (fun ?loc c -> f ?loc (DAst.get_thunk c)) n +let error_cannot_coerce_wildcard_term ?loc () = + user_err ?loc Pp.(str "Cannot turn \"_\" into a term.") + +let error_cannot_coerce_disjunctive_pattern_term ?loc () = + user_err ?loc Pp.(str "Cannot turn a disjunctive pattern into a term.") + let terms_of_binders bl = let rec term_of_pat pt = dmap_with_loc (fun ?loc -> function | PatVar (Name id) -> CRef (Ident (loc,id), None) - | PatVar (Anonymous) -> user_err Pp.(str "Cannot turn \"_\" into a term.") + | PatVar (Anonymous) -> error_cannot_coerce_wildcard_term ?loc () | PatCstr (c,l,_) -> let r = Qualid (loc,qualid_of_path (path_of_global (ConstructRef c))) in let hole = CAst.make ?loc @@ CHole (None,Misctypes.IntroAnonymous,None) in @@ -599,7 +612,8 @@ let terms_of_binders bl = | GLocalDef (Name id,_,_,_) -> extract_variables l | GLocalDef (Anonymous,_,_,_) | GLocalAssum (Anonymous,_,_) -> user_err Pp.(str "Cannot turn \"_\" into a term.") - | GLocalPattern ((u,_),_,_,_) -> term_of_pat u :: extract_variables l + | GLocalPattern (([u],_),_,_,_) -> term_of_pat u :: extract_variables l + | GLocalPattern ((_,_),_,_,_) -> error_cannot_coerce_disjunctive_pattern_term ?loc () end | [] -> [] in extract_variables bl @@ -676,8 +690,10 @@ let instantiate_notation_constr loc intern intern_pat ntnvars subst infos c = in let mk_env' (c, (tmp_scope, subscopes)) = let nenv = {env with tmp_scope; scopes = subscopes @ env.scopes} in - let _,((pat,_),_),_ = intern_pat ntnvars nenv c in - (glob_constr_of_cases_pattern pat, None) + let _,((disjpat,_),_),_ = intern_pat ntnvars nenv c in + match disjpat with + | [pat] -> (glob_constr_of_cases_pattern pat, None) + | _ -> error_cannot_coerce_disjunctive_pattern_term ?loc:c.CAst.loc () in let terms = Id.Map.map mk_env terms in let binders = Id.Map.map mk_env' binders in @@ -708,14 +724,14 @@ let instantiate_notation_constr loc intern intern_pat ntnvars subst infos c = (* Two special cases to keep binder name synchronous with BinderType *) | NProd (na,NHole(Evar_kinds.BinderType na',naming,arg),c') when Name.equal na na' -> - let subinfos,pat,na = traverse_binder intern_pat ntnvars subst avoid subinfos na in + let subinfos,disjpat,na = traverse_binder intern_pat ntnvars subst avoid subinfos na in let ty = DAst.make ?loc @@ GHole (Evar_kinds.BinderType na,naming,arg) in - DAst.make ?loc @@ GProd (na,Explicit,ty,Option.fold_right apply_cases_pattern pat (aux subst' subinfos c')) + DAst.make ?loc @@ GProd (na,Explicit,ty,Option.fold_right apply_cases_pattern disjpat (aux subst' subinfos c')) | NLambda (na,NHole(Evar_kinds.BinderType na',naming,arg),c') when Name.equal na na' -> - let subinfos,pat,na = traverse_binder intern_pat ntnvars subst avoid subinfos na in + let subinfos,disjpat,na = traverse_binder intern_pat ntnvars subst avoid subinfos na in let ty = DAst.make ?loc @@ GHole (Evar_kinds.BinderType na,naming,arg) in - DAst.make ?loc @@ GLambda (na,Explicit,ty,Option.fold_right apply_cases_pattern pat (aux subst' subinfos c')) + DAst.make ?loc @@ GLambda (na,Explicit,ty,Option.fold_right apply_cases_pattern disjpat (aux subst' subinfos c')) | t -> glob_constr_of_notation_constr_with_binders ?loc (traverse_binder intern_pat ntnvars subst avoid) (aux subst') subinfos t @@ -730,11 +746,13 @@ let instantiate_notation_constr loc intern intern_pat ntnvars subst infos c = try let pat,scopes = Id.Map.find id binders in let env = set_env_scopes env scopes in - (* We deactivate the check on hidden parameters *) - (* since we are only interested in the pattern as a term *) + (* We deactivate impls to avoid the check on hidden parameters *) + (* and since we are only interested in the pattern as a term *) let env = reset_hidden_inductive_implicit_test env in - let env,((pat,ids),id),na = intern_pat ntnvars env pat in - glob_constr_of_cases_pattern pat + let env,((disjpat,ids),id),na = intern_pat ntnvars env pat in + match disjpat with + | [pat] -> glob_constr_of_cases_pattern pat + | _ -> user_err Pp.(str "Cannot turn a disjunctive pattern into a term.") with Not_found -> try match binderopt with diff --git a/interp/notation_ops.ml b/interp/notation_ops.ml index c44863791..81cdecf03 100644 --- a/interp/notation_ops.ml +++ b/interp/notation_ops.ml @@ -101,17 +101,24 @@ let name_to_ident = function let to_id g e id = let e,na = g e (Name id) in e,name_to_ident na +let product_of_cases_patterns patl = + List.fold_right (fun patl restl -> + List.flatten (List.map (fun p -> List.map (fun rest -> p::rest) restl) patl)) + patl [[]] + let rec cases_pattern_fold_map ?loc g e = DAst.with_val (function | PatVar na -> - let e',pat,na' = g e na in - e', (match pat with - | None -> DAst.make ?loc @@ PatVar na' - | Some ((_,pat),_) -> pat) + let e',disjpat,na' = g e na in + e', (match disjpat with + | None -> [DAst.make ?loc @@ PatVar na'] + | Some ((_,disjpat),_) -> disjpat) | PatCstr (cstr,patl,na) -> - let e',pat,na' = g e na in - if pat <> None then user_err (Pp.str "Unable to instantiate an \"as\" clause with a pattern."); + let e',disjpat,na' = g e na in + if disjpat <> None then user_err (Pp.str "Unable to instantiate an \"as\" clause with a pattern."); let e',patl' = List.fold_left_map (cases_pattern_fold_map ?loc g) e patl in - e', DAst.make ?loc @@ PatCstr (cstr,patl',na') + (* Distribute outwards the inner disjunctive patterns *) + let disjpatl' = product_of_cases_patterns patl' in + e', List.map (fun patl' -> DAst.make ?loc @@ PatCstr (cstr,patl',na')) disjpatl' ) let subst_binder_type_vars l = function @@ -141,14 +148,14 @@ let rec subst_glob_vars l gc = DAst.map (function let ldots_var = Id.of_string ".." let protect g e na = - let e',pat,na = g e na in - if pat <> None then user_err (Pp.str "Unsupported substitution of an arbitrary pattern."); + let e',disjpat,na = g e na in + if disjpat <> None then user_err (Pp.str "Unsupported substitution of an arbitrary pattern."); e',na -let apply_cases_pattern ?loc ((ids,pat),id) c = +let apply_cases_pattern ?loc ((ids,disjpat),id) c = let tm = DAst.make ?loc (GVar id) in - DAst.make ?loc @@ - GCases (LetPatternStyle, None, [tm,(Anonymous,None)], [loc,(ids,[pat], c)]) + let eqns = List.map (fun pat -> (loc,(ids,[pat],c))) disjpat in + DAst.make ?loc @@ GCases (LetPatternStyle, None, [tm,(Anonymous,None)], eqns) let glob_constr_of_notation_constr_with_binders ?loc g f e nc = let lt x = DAst.make ?loc x in lt @@ match nc with @@ -167,14 +174,14 @@ let glob_constr_of_notation_constr_with_binders ?loc g f e nc = let outerl = (ldots_var,inner)::(if swap then [x, lt @@ GVar y] else []) in DAst.get (subst_glob_vars outerl it) | NLambda (na,ty,c) -> - let e',pat,na = g e na in GLambda (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) pat (f e' c)) + let e',disjpat,na = g e na in GLambda (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) disjpat (f e' c)) | NProd (na,ty,c) -> - let e',pat,na = g e na in GProd (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) pat (f e' c)) + let e',disjpat,na = g e na in GProd (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) disjpat (f e' c)) | NLetIn (na,b,t,c) -> - let e',pat,na = g e na in - (match pat with + let e',disjpat,na = g e na in + (match disjpat with | None -> GLetIn (na,f e b,Option.map (f e) t,f e' c) - | Some pat -> DAst.get (apply_cases_pattern ?loc pat (f e' c))) + | Some disjpat -> DAst.get (apply_cases_pattern ?loc disjpat (f e' c))) | NCases (sty,rtntypopt,tml,eqnl) -> let e',tml' = List.fold_right (fun (tm,(na,t)) (e',tml') -> let e',t' = match t with @@ -183,15 +190,16 @@ let glob_constr_of_notation_constr_with_binders ?loc g f e nc = let e',nal' = List.fold_right (fun na (e',nal) -> let e',na' = protect g e' na in e',na'::nal) nal (e',[]) in - e',Some (Loc.tag ?loc (ind,nal')) in + e',Some (Loc.tag ?loc (ind,nal')) in let e',na' = protect g e' na in - (e',(f e tm,(na',t'))::tml')) tml (e,[]) in - let fold (idl,e) na = let (e,pat,na) = g e na in ((Name.cons na idl,e),pat,na) in + (e',(f e tm,(na',t'))::tml')) tml (e,[]) in + let fold (idl,e) na = let (e,disjpat,na) = g e na in ((Name.cons na idl,e),disjpat,na) in let eqnl' = List.map (fun (patl,rhs) -> - let ((idl,e),patl) = - List.fold_left_map (cases_pattern_fold_map ?loc fold) ([],e) patl in - Loc.tag (idl,patl,f e rhs)) eqnl in - GCases (sty,Option.map (f e') rtntypopt,tml',eqnl') + let ((idl,e),patl) = + List.fold_left_map (cases_pattern_fold_map ?loc fold) ([],e) patl in + let disjpatl = product_of_cases_patterns patl in + List.map (fun patl -> Loc.tag (idl,patl,f e rhs)) disjpatl) eqnl in + GCases (sty,Option.map (f e') rtntypopt,tml',List.flatten eqnl') | NLetTuple (nal,(na,po),b,c) -> let e',nal = List.fold_left_map (protect g) e nal in let e'',na = protect g e na in @@ -806,8 +814,8 @@ let unify_binder_upto alp b b' = | GLocalDef (na,bk,c,t), GLocalDef (na',bk',c',t') -> let alp, na = unify_name_upto alp na na' in alp, DAst.make ?loc @@ GLocalDef (na, unify_binding_kind bk bk', unify_term alp c c', unify_opt_term alp t t') - | GLocalPattern ((p,ids),id,bk,t), GLocalPattern ((p',_),_,bk',t') -> - let alp, p = unify_pat_upto alp p p' in + | GLocalPattern ((disjpat,ids),id,bk,t), GLocalPattern ((disjpat',_),_,bk',t') when List.length disjpat = List.length disjpat' -> + let alp, p = List.fold_left2_map unify_pat_upto alp disjpat disjpat' in alp, DAst.make ?loc @@ GLocalPattern ((p,ids), id, unify_binding_kind bk bk', unify_term alp t t') | _ -> raise No_match @@ -840,9 +848,9 @@ let unify_term_binder alp c = DAst.(map (fun b' -> match DAst.get c, b' with | GVar id, GLocalAssum (na', bk', t') -> GLocalAssum (unify_id alp id na', bk', t') - | _, GLocalPattern ((p',ids), id, bk', t') -> + | _, GLocalPattern (([p'],ids), id, bk', t') -> let p = pat_binder_of_term c in - GLocalPattern ((unify_pat alp p p',ids), id, bk', t') + GLocalPattern (([unify_pat alp p p'],ids), id, bk', t') | _ -> raise No_match)) let rec unify_terms_binders alp cl bl' = @@ -895,23 +903,23 @@ let bind_binding_as_term_env alp (terms,termlists,binders,binderlists as sigma) let pat = try force_cases_pattern (cases_pattern_of_glob_constr Anonymous c) with Not_found -> raise No_match in try (* If already bound to a binder, unify the term and the binder *) - let pat' = Id.List.assoc var binders in - let pat'' = unify_pat alp pat pat' in - if pat' == pat'' then sigma + let patl' = Id.List.assoc var binders in + let patl'' = List.map2 (unify_pat alp) [pat] patl' in + if patl' == patl'' then sigma else let sigma = (terms,termlists,Id.List.remove_assoc var binders,binderlists) in - add_binding_env alp sigma var pat'' - with Not_found -> add_binding_env alp sigma var pat + add_binding_env alp sigma var patl'' + with Not_found -> add_binding_env alp sigma var [pat] -let bind_binding_env alp (terms,termlists,binders,binderlists as sigma) var pat = +let bind_binding_env alp (terms,termlists,binders,binderlists as sigma) var disjpat = try (* If already bound to a binder possibly *) (* generating an alpha-renaming from unifying the new binder *) - let pat' = Id.List.assoc var binders in - let alp, pat = unify_pat_upto alp pat pat' in + let disjpat' = Id.List.assoc var binders in + let alp, disjpat = List.fold_left2_map unify_pat_upto alp disjpat disjpat' in let sigma = (terms,termlists,Id.List.remove_assoc var binders,binderlists) in - alp, add_binding_env alp sigma var pat - with Not_found -> alp, add_binding_env alp sigma var pat + alp, add_binding_env alp sigma var disjpat + with Not_found -> alp, add_binding_env alp sigma var disjpat let bind_bindinglist_env alp (terms,termlists,binders,binderlists as sigma) var bl = let bl = List.rev bl in @@ -955,7 +963,7 @@ let match_opt f sigma t1 t2 = match (t1,t2) with let match_names metas (alp,sigma) na1 na2 = match (na1,na2) with | (na1,Name id2) when is_onlybinding_meta id2 metas -> - bind_binding_env alp sigma id2 (DAst.make (PatVar na1)) + bind_binding_env alp sigma id2 [DAst.make (PatVar na1)] | (Name id1,Name id2) when is_term_meta id2 metas -> (* We let the non-binding occurrence define the rhs and hence reason up to *) (* alpha-conversion for the given occurrence of the name (see #4592)) *) @@ -970,7 +978,7 @@ let match_names metas (alp,sigma) na1 na2 = match (na1,na2) with let rec match_cases_pattern_binders metas (alp,sigma as acc) pat1 pat2 = match DAst.get pat1, DAst.get pat2 with | _, PatVar (Name id2) when is_onlybinding_pattern_like_meta id2 metas -> - bind_binding_env alp sigma id2 pat1 + bind_binding_env alp sigma id2 [pat1] | PatVar na1, PatVar na2 -> match_names metas acc na1 na2 | PatCstr (c1,patl1,na1), PatCstr (c2,patl2,na2) when eq_constructor c1 c2 && Int.equal (List.length patl1) (List.length patl2) -> @@ -1205,18 +1213,27 @@ and match_binders u alp metas na1 na2 sigma b1 b2 = and match_extended_binders ?loc isprod u alp metas na1 na2 bk t sigma b1 b2 = (* Match binders which can be substituted by a pattern *) + let store, get = set_temporary_memory () in match na1, DAst.get b1, na2 with (* Matching individual binders as part of a recursive pattern *) - | Name p, GCases (LetPatternStyle,None,[(e,_)],[(_,(ids,[cp],b1))]), Name id - when is_gvar p e && is_bindinglist_meta id metas -> - let cp = if occur_glob_constr p b1 then set_pat_alias p cp else cp in - let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalPattern ((cp,ids),p,bk,t)] in + | Name p, GCases (LetPatternStyle,None,[(e,_)],(_::_ as eqns)), Name id + when is_gvar p e && is_bindinglist_meta id metas && List.length (store (Detyping.factorize_eqns eqns)) = 1 -> + (match get () with + | [(_,(ids,disj_of_patl,b1))] -> + let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in + let disjpat = if occur_glob_constr p b1 then List.map (set_pat_alias p) disjpat else disjpat in + let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalPattern ((disjpat,ids),p,bk,t)] in match_in u alp metas sigma b1 b2 - | Name p, GCases (LetPatternStyle,None,[(e,_)],[(_,(_,[cp],b1))]), Name id - when is_gvar p e && is_onlybinding_pattern_like_meta id metas -> - let cp = if occur_glob_constr p b1 then set_pat_alias p cp else cp in - let alp,sigma = bind_binding_env alp sigma id cp in + | _ -> assert false) + | Name p, GCases (LetPatternStyle,None,[(e,_)],(_::_ as eqns)), Name id + when is_gvar p e && is_onlybinding_pattern_like_meta id metas && List.length (store (Detyping.factorize_eqns eqns)) = 1 -> + (match get () with + | [(_,(ids,disj_of_patl,b1))] -> + let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in + let disjpat = if occur_glob_constr p b1 then List.map (set_pat_alias p) disjpat else disjpat in + let alp,sigma = bind_binding_env alp sigma id disjpat in match_in u alp metas sigma b1 b2 + | _ -> assert false) | _, _, Name id when is_bindinglist_meta id metas && (not isprod || na1 != Anonymous)-> let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalAssum (na1,bk,t)] in match_in u alp metas sigma b1 b2 @@ -1243,8 +1260,11 @@ let match_notation_constr u c (metas,pat) = let term = try Id.List.assoc x terms with Not_found -> raise No_match in ((term, scl)::terms',termlists',binders',binderlists') | NtnTypeBinder NtnParsedAsConstr -> - let v = glob_constr_of_cases_pattern (Id.List.assoc x binders) in - ((v,scl)::terms',termlists',binders',binderlists') + (match Id.List.assoc x binders with + | [pat] -> + let v = glob_constr_of_cases_pattern pat in + ((v,scl)::terms',termlists',binders',binderlists') + | _ -> raise No_match) | NtnTypeBinder (NtnParsedAsIdent | NtnParsedAsPattern) -> (terms',termlists',(Id.List.assoc x binders,scl)::binders',binderlists') | NtnTypeConstrList -> diff --git a/interp/notation_ops.mli b/interp/notation_ops.mli index 1a2dfc9ca..746f52e48 100644 --- a/interp/notation_ops.mli +++ b/interp/notation_ops.mli @@ -34,10 +34,10 @@ val notation_constr_of_glob_constr : notation_interp_env -> (** Re-interpret a notation as a [glob_constr], taking care of binders *) val apply_cases_pattern : ?loc:Loc.t -> - (Id.t list * cases_pattern) * Id.t -> glob_constr -> glob_constr + (Id.t list * cases_pattern_disjunction) * Id.t -> glob_constr -> glob_constr val glob_constr_of_notation_constr_with_binders : ?loc:Loc.t -> - ('a -> Name.t -> 'a * ((Id.t list * cases_pattern) * Id.t) option * Name.t) -> + ('a -> Name.t -> 'a * ((Id.t list * cases_pattern_disjunction) * Id.t) option * Name.t) -> ('a -> notation_constr -> glob_constr) -> 'a -> notation_constr -> glob_constr @@ -52,7 +52,7 @@ exception No_match val match_notation_constr : bool -> 'a glob_constr_g -> interpretation -> ('a glob_constr_g * subscopes) list * ('a glob_constr_g list * subscopes) list * - ('a cases_pattern_g * subscopes) list * + ('a cases_pattern_disjunction_g * subscopes) list * ('a extended_glob_local_binder_g list * subscopes) list val match_notation_constr_cases_pattern : -- cgit v1.2.3