diff options
Diffstat (limited to 'interp/notation_ops.ml')
-rw-r--r-- | interp/notation_ops.ml | 120 |
1 files changed, 70 insertions, 50 deletions
diff --git a/interp/notation_ops.ml b/interp/notation_ops.ml index c44863791..81cdecf03 100644 --- a/interp/notation_ops.ml +++ b/interp/notation_ops.ml @@ -101,17 +101,24 @@ let name_to_ident = function let to_id g e id = let e,na = g e (Name id) in e,name_to_ident na +let product_of_cases_patterns patl = + List.fold_right (fun patl restl -> + List.flatten (List.map (fun p -> List.map (fun rest -> p::rest) restl) patl)) + patl [[]] + let rec cases_pattern_fold_map ?loc g e = DAst.with_val (function | PatVar na -> - let e',pat,na' = g e na in - e', (match pat with - | None -> DAst.make ?loc @@ PatVar na' - | Some ((_,pat),_) -> pat) + let e',disjpat,na' = g e na in + e', (match disjpat with + | None -> [DAst.make ?loc @@ PatVar na'] + | Some ((_,disjpat),_) -> disjpat) | PatCstr (cstr,patl,na) -> - let e',pat,na' = g e na in - if pat <> None then user_err (Pp.str "Unable to instantiate an \"as\" clause with a pattern."); + let e',disjpat,na' = g e na in + if disjpat <> None then user_err (Pp.str "Unable to instantiate an \"as\" clause with a pattern."); let e',patl' = List.fold_left_map (cases_pattern_fold_map ?loc g) e patl in - e', DAst.make ?loc @@ PatCstr (cstr,patl',na') + (* Distribute outwards the inner disjunctive patterns *) + let disjpatl' = product_of_cases_patterns patl' in + e', List.map (fun patl' -> DAst.make ?loc @@ PatCstr (cstr,patl',na')) disjpatl' ) let subst_binder_type_vars l = function @@ -141,14 +148,14 @@ let rec subst_glob_vars l gc = DAst.map (function let ldots_var = Id.of_string ".." let protect g e na = - let e',pat,na = g e na in - if pat <> None then user_err (Pp.str "Unsupported substitution of an arbitrary pattern."); + let e',disjpat,na = g e na in + if disjpat <> None then user_err (Pp.str "Unsupported substitution of an arbitrary pattern."); e',na -let apply_cases_pattern ?loc ((ids,pat),id) c = +let apply_cases_pattern ?loc ((ids,disjpat),id) c = let tm = DAst.make ?loc (GVar id) in - DAst.make ?loc @@ - GCases (LetPatternStyle, None, [tm,(Anonymous,None)], [loc,(ids,[pat], c)]) + let eqns = List.map (fun pat -> (loc,(ids,[pat],c))) disjpat in + DAst.make ?loc @@ GCases (LetPatternStyle, None, [tm,(Anonymous,None)], eqns) let glob_constr_of_notation_constr_with_binders ?loc g f e nc = let lt x = DAst.make ?loc x in lt @@ match nc with @@ -167,14 +174,14 @@ let glob_constr_of_notation_constr_with_binders ?loc g f e nc = let outerl = (ldots_var,inner)::(if swap then [x, lt @@ GVar y] else []) in DAst.get (subst_glob_vars outerl it) | NLambda (na,ty,c) -> - let e',pat,na = g e na in GLambda (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) pat (f e' c)) + let e',disjpat,na = g e na in GLambda (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) disjpat (f e' c)) | NProd (na,ty,c) -> - let e',pat,na = g e na in GProd (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) pat (f e' c)) + let e',disjpat,na = g e na in GProd (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) disjpat (f e' c)) | NLetIn (na,b,t,c) -> - let e',pat,na = g e na in - (match pat with + let e',disjpat,na = g e na in + (match disjpat with | None -> GLetIn (na,f e b,Option.map (f e) t,f e' c) - | Some pat -> DAst.get (apply_cases_pattern ?loc pat (f e' c))) + | Some disjpat -> DAst.get (apply_cases_pattern ?loc disjpat (f e' c))) | NCases (sty,rtntypopt,tml,eqnl) -> let e',tml' = List.fold_right (fun (tm,(na,t)) (e',tml') -> let e',t' = match t with @@ -183,15 +190,16 @@ let glob_constr_of_notation_constr_with_binders ?loc g f e nc = let e',nal' = List.fold_right (fun na (e',nal) -> let e',na' = protect g e' na in e',na'::nal) nal (e',[]) in - e',Some (Loc.tag ?loc (ind,nal')) in + e',Some (Loc.tag ?loc (ind,nal')) in let e',na' = protect g e' na in - (e',(f e tm,(na',t'))::tml')) tml (e,[]) in - let fold (idl,e) na = let (e,pat,na) = g e na in ((Name.cons na idl,e),pat,na) in + (e',(f e tm,(na',t'))::tml')) tml (e,[]) in + let fold (idl,e) na = let (e,disjpat,na) = g e na in ((Name.cons na idl,e),disjpat,na) in let eqnl' = List.map (fun (patl,rhs) -> - let ((idl,e),patl) = - List.fold_left_map (cases_pattern_fold_map ?loc fold) ([],e) patl in - Loc.tag (idl,patl,f e rhs)) eqnl in - GCases (sty,Option.map (f e') rtntypopt,tml',eqnl') + let ((idl,e),patl) = + List.fold_left_map (cases_pattern_fold_map ?loc fold) ([],e) patl in + let disjpatl = product_of_cases_patterns patl in + List.map (fun patl -> Loc.tag (idl,patl,f e rhs)) disjpatl) eqnl in + GCases (sty,Option.map (f e') rtntypopt,tml',List.flatten eqnl') | NLetTuple (nal,(na,po),b,c) -> let e',nal = List.fold_left_map (protect g) e nal in let e'',na = protect g e na in @@ -806,8 +814,8 @@ let unify_binder_upto alp b b' = | GLocalDef (na,bk,c,t), GLocalDef (na',bk',c',t') -> let alp, na = unify_name_upto alp na na' in alp, DAst.make ?loc @@ GLocalDef (na, unify_binding_kind bk bk', unify_term alp c c', unify_opt_term alp t t') - | GLocalPattern ((p,ids),id,bk,t), GLocalPattern ((p',_),_,bk',t') -> - let alp, p = unify_pat_upto alp p p' in + | GLocalPattern ((disjpat,ids),id,bk,t), GLocalPattern ((disjpat',_),_,bk',t') when List.length disjpat = List.length disjpat' -> + let alp, p = List.fold_left2_map unify_pat_upto alp disjpat disjpat' in alp, DAst.make ?loc @@ GLocalPattern ((p,ids), id, unify_binding_kind bk bk', unify_term alp t t') | _ -> raise No_match @@ -840,9 +848,9 @@ let unify_term_binder alp c = DAst.(map (fun b' -> match DAst.get c, b' with | GVar id, GLocalAssum (na', bk', t') -> GLocalAssum (unify_id alp id na', bk', t') - | _, GLocalPattern ((p',ids), id, bk', t') -> + | _, GLocalPattern (([p'],ids), id, bk', t') -> let p = pat_binder_of_term c in - GLocalPattern ((unify_pat alp p p',ids), id, bk', t') + GLocalPattern (([unify_pat alp p p'],ids), id, bk', t') | _ -> raise No_match)) let rec unify_terms_binders alp cl bl' = @@ -895,23 +903,23 @@ let bind_binding_as_term_env alp (terms,termlists,binders,binderlists as sigma) let pat = try force_cases_pattern (cases_pattern_of_glob_constr Anonymous c) with Not_found -> raise No_match in try (* If already bound to a binder, unify the term and the binder *) - let pat' = Id.List.assoc var binders in - let pat'' = unify_pat alp pat pat' in - if pat' == pat'' then sigma + let patl' = Id.List.assoc var binders in + let patl'' = List.map2 (unify_pat alp) [pat] patl' in + if patl' == patl'' then sigma else let sigma = (terms,termlists,Id.List.remove_assoc var binders,binderlists) in - add_binding_env alp sigma var pat'' - with Not_found -> add_binding_env alp sigma var pat + add_binding_env alp sigma var patl'' + with Not_found -> add_binding_env alp sigma var [pat] -let bind_binding_env alp (terms,termlists,binders,binderlists as sigma) var pat = +let bind_binding_env alp (terms,termlists,binders,binderlists as sigma) var disjpat = try (* If already bound to a binder possibly *) (* generating an alpha-renaming from unifying the new binder *) - let pat' = Id.List.assoc var binders in - let alp, pat = unify_pat_upto alp pat pat' in + let disjpat' = Id.List.assoc var binders in + let alp, disjpat = List.fold_left2_map unify_pat_upto alp disjpat disjpat' in let sigma = (terms,termlists,Id.List.remove_assoc var binders,binderlists) in - alp, add_binding_env alp sigma var pat - with Not_found -> alp, add_binding_env alp sigma var pat + alp, add_binding_env alp sigma var disjpat + with Not_found -> alp, add_binding_env alp sigma var disjpat let bind_bindinglist_env alp (terms,termlists,binders,binderlists as sigma) var bl = let bl = List.rev bl in @@ -955,7 +963,7 @@ let match_opt f sigma t1 t2 = match (t1,t2) with let match_names metas (alp,sigma) na1 na2 = match (na1,na2) with | (na1,Name id2) when is_onlybinding_meta id2 metas -> - bind_binding_env alp sigma id2 (DAst.make (PatVar na1)) + bind_binding_env alp sigma id2 [DAst.make (PatVar na1)] | (Name id1,Name id2) when is_term_meta id2 metas -> (* We let the non-binding occurrence define the rhs and hence reason up to *) (* alpha-conversion for the given occurrence of the name (see #4592)) *) @@ -970,7 +978,7 @@ let match_names metas (alp,sigma) na1 na2 = match (na1,na2) with let rec match_cases_pattern_binders metas (alp,sigma as acc) pat1 pat2 = match DAst.get pat1, DAst.get pat2 with | _, PatVar (Name id2) when is_onlybinding_pattern_like_meta id2 metas -> - bind_binding_env alp sigma id2 pat1 + bind_binding_env alp sigma id2 [pat1] | PatVar na1, PatVar na2 -> match_names metas acc na1 na2 | PatCstr (c1,patl1,na1), PatCstr (c2,patl2,na2) when eq_constructor c1 c2 && Int.equal (List.length patl1) (List.length patl2) -> @@ -1205,18 +1213,27 @@ and match_binders u alp metas na1 na2 sigma b1 b2 = and match_extended_binders ?loc isprod u alp metas na1 na2 bk t sigma b1 b2 = (* Match binders which can be substituted by a pattern *) + let store, get = set_temporary_memory () in match na1, DAst.get b1, na2 with (* Matching individual binders as part of a recursive pattern *) - | Name p, GCases (LetPatternStyle,None,[(e,_)],[(_,(ids,[cp],b1))]), Name id - when is_gvar p e && is_bindinglist_meta id metas -> - let cp = if occur_glob_constr p b1 then set_pat_alias p cp else cp in - let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalPattern ((cp,ids),p,bk,t)] in + | Name p, GCases (LetPatternStyle,None,[(e,_)],(_::_ as eqns)), Name id + when is_gvar p e && is_bindinglist_meta id metas && List.length (store (Detyping.factorize_eqns eqns)) = 1 -> + (match get () with + | [(_,(ids,disj_of_patl,b1))] -> + let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in + let disjpat = if occur_glob_constr p b1 then List.map (set_pat_alias p) disjpat else disjpat in + let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalPattern ((disjpat,ids),p,bk,t)] in match_in u alp metas sigma b1 b2 - | Name p, GCases (LetPatternStyle,None,[(e,_)],[(_,(_,[cp],b1))]), Name id - when is_gvar p e && is_onlybinding_pattern_like_meta id metas -> - let cp = if occur_glob_constr p b1 then set_pat_alias p cp else cp in - let alp,sigma = bind_binding_env alp sigma id cp in + | _ -> assert false) + | Name p, GCases (LetPatternStyle,None,[(e,_)],(_::_ as eqns)), Name id + when is_gvar p e && is_onlybinding_pattern_like_meta id metas && List.length (store (Detyping.factorize_eqns eqns)) = 1 -> + (match get () with + | [(_,(ids,disj_of_patl,b1))] -> + let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in + let disjpat = if occur_glob_constr p b1 then List.map (set_pat_alias p) disjpat else disjpat in + let alp,sigma = bind_binding_env alp sigma id disjpat in match_in u alp metas sigma b1 b2 + | _ -> assert false) | _, _, Name id when is_bindinglist_meta id metas && (not isprod || na1 != Anonymous)-> let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalAssum (na1,bk,t)] in match_in u alp metas sigma b1 b2 @@ -1243,8 +1260,11 @@ let match_notation_constr u c (metas,pat) = let term = try Id.List.assoc x terms with Not_found -> raise No_match in ((term, scl)::terms',termlists',binders',binderlists') | NtnTypeBinder NtnParsedAsConstr -> - let v = glob_constr_of_cases_pattern (Id.List.assoc x binders) in - ((v,scl)::terms',termlists',binders',binderlists') + (match Id.List.assoc x binders with + | [pat] -> + let v = glob_constr_of_cases_pattern pat in + ((v,scl)::terms',termlists',binders',binderlists') + | _ -> raise No_match) | NtnTypeBinder (NtnParsedAsIdent | NtnParsedAsPattern) -> (terms',termlists',(Id.List.assoc x binders,scl)::binders',binderlists') | NtnTypeConstrList -> |