diff options
Diffstat (limited to 'interp/notation_ops.ml')
-rw-r--r-- | interp/notation_ops.ml | 841 |
1 files changed, 458 insertions, 383 deletions
diff --git a/interp/notation_ops.ml b/interp/notation_ops.ml index 326d05cba..c65f4785e 100644 --- a/interp/notation_ops.ml +++ b/interp/notation_ops.ml @@ -42,9 +42,9 @@ let rec eq_notation_constr (vars1,vars2 as vars) t1 t2 = match t1, t2 with Name.equal na1 na2 && (eq_notation_constr vars) t1 t2 && (eq_notation_constr vars) u1 u2 | NProd (na1, t1, u1), NProd (na2, t2, u2) -> Name.equal na1 na2 && (eq_notation_constr vars) t1 t2 && (eq_notation_constr vars) u1 u2 -| NBinderList (i1, j1, t1, u1), NBinderList (i2, j2, t2, u2) -> +| NBinderList (i1, j1, t1, u1, b1), NBinderList (i2, j2, t2, u2, b2) -> Id.equal i1 i2 && Id.equal j1 j2 && (eq_notation_constr vars) t1 t2 && - (eq_notation_constr vars) u1 u2 + (eq_notation_constr vars) u1 u2 && b1 == b2 | NLetIn (na1, b1, t1, u1), NLetIn (na2, b2, t2, u2) -> Name.equal na1 na2 && eq_notation_constr vars b1 b2 && Option.equal (eq_notation_constr vars) t1 t2 && (eq_notation_constr vars) u1 u2 @@ -101,13 +101,24 @@ let name_to_ident = function let to_id g e id = let e,na = g e (Name id) in e,name_to_ident na +let product_of_cases_patterns patl = + List.fold_right (fun patl restl -> + List.flatten (List.map (fun p -> List.map (fun rest -> p::rest) restl) patl)) + patl [[]] + let rec cases_pattern_fold_map ?loc g e = DAst.with_val (function | PatVar na -> - let e',na' = g e na in e', DAst.make ?loc @@ PatVar na' + let e',disjpat,na' = g e na in + e', (match disjpat with + | None -> [DAst.make ?loc @@ PatVar na'] + | Some ((_,disjpat),_) -> disjpat) | PatCstr (cstr,patl,na) -> - let e',na' = g e na in + let e',disjpat,na' = g e na in + if disjpat <> None then user_err (Pp.str "Unable to instantiate an \"as\" clause with a pattern."); let e',patl' = List.fold_left_map (cases_pattern_fold_map ?loc g) e patl in - e', DAst.make ?loc @@ PatCstr (cstr,patl',na') + (* Distribute outwards the inner disjunctive patterns *) + let disjpatl' = product_of_cases_patterns patl' in + e', List.map (fun patl' -> DAst.make ?loc @@ PatCstr (cstr,patl',na')) disjpatl' ) let subst_binder_type_vars l = function @@ -136,6 +147,16 @@ let rec subst_glob_vars l gc = DAst.map (function let ldots_var = Id.of_string ".." +let protect g e na = + let e',disjpat,na = g e na in + if disjpat <> None then user_err (Pp.str "Unsupported substitution of an arbitrary pattern."); + e',na + +let apply_cases_pattern ?loc ((ids,disjpat),id) c = + let tm = DAst.make ?loc (GVar id) in + let eqns = List.map (fun pat -> (loc,(ids,[pat],c))) disjpat in + DAst.make ?loc @@ GCases (LetPatternStyle, None, [tm,(Anonymous,None)], eqns) + let glob_constr_of_notation_constr_with_binders ?loc g f e nc = let lt x = DAst.make ?loc x in lt @@ match nc with | NVar id -> GVar id @@ -146,46 +167,51 @@ let glob_constr_of_notation_constr_with_binders ?loc g f e nc = let inner = lt @@ GApp (lt @@ GVar (ldots_var),[subst_glob_vars innerl it]) in let outerl = (ldots_var,inner)::(if swap then [x, lt @@ GVar y] else []) in DAst.get (subst_glob_vars outerl it) - | NBinderList (x,y,iter,tail) -> + | NBinderList (x,y,iter,tail,swap) -> let t = f e tail in let it = f e iter in - let innerl = [(ldots_var,t);(x, lt @@ GVar y)] in + let innerl = (ldots_var,t)::(if swap then [] else [x, lt @@ GVar y]) in let inner = lt @@ GApp (lt @@ GVar ldots_var,[subst_glob_vars innerl it]) in - let outerl = [(ldots_var,inner)] in + let outerl = (ldots_var,inner)::(if swap then [x, lt @@ GVar y] else []) in DAst.get (subst_glob_vars outerl it) | NLambda (na,ty,c) -> - let e',na = g e na in GLambda (na,Explicit,f e ty,f e' c) + let e',disjpat,na = g e na in GLambda (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) disjpat (f e' c)) | NProd (na,ty,c) -> - let e',na = g e na in GProd (na,Explicit,f e ty,f e' c) + let e',disjpat,na = g e na in GProd (na,Explicit,f e ty,Option.fold_right (apply_cases_pattern ?loc) disjpat (f e' c)) | NLetIn (na,b,t,c) -> - let e',na = g e na in GLetIn (na,f e b,Option.map (f e) t,f e' c) + let e',disjpat,na = g e na in + (match disjpat with + | None -> GLetIn (na,f e b,Option.map (f e) t,f e' c) + | Some disjpat -> DAst.get (apply_cases_pattern ?loc disjpat (f e' c))) | NCases (sty,rtntypopt,tml,eqnl) -> let e',tml' = List.fold_right (fun (tm,(na,t)) (e',tml') -> let e',t' = match t with | None -> e',None | Some (ind,nal) -> let e',nal' = List.fold_right (fun na (e',nal) -> - let e',na' = g e' na in e',na'::nal) nal (e',[]) in - e',Some (Loc.tag ?loc (ind,nal')) in - let e',na' = g e' na in - (e',(f e tm,(na',t'))::tml')) tml (e,[]) in - let fold (idl,e) na = let (e,na) = g e na in ((Name.cons na idl,e),na) in + let e',na' = protect g e' na in + e',na'::nal) nal (e',[]) in + e',Some (Loc.tag ?loc (ind,nal')) in + let e',na' = protect g e' na in + (e',(f e tm,(na',t'))::tml')) tml (e,[]) in + let fold (idl,e) na = let (e,disjpat,na) = g e na in ((Name.cons na idl,e),disjpat,na) in let eqnl' = List.map (fun (patl,rhs) -> - let ((idl,e),patl) = - List.fold_left_map (cases_pattern_fold_map ?loc fold) ([],e) patl in - Loc.tag (idl,patl,f e rhs)) eqnl in - GCases (sty,Option.map (f e') rtntypopt,tml',eqnl') + let ((idl,e),patl) = + List.fold_left_map (cases_pattern_fold_map ?loc fold) ([],e) patl in + let disjpatl = product_of_cases_patterns patl in + List.map (fun patl -> Loc.tag (idl,patl,f e rhs)) disjpatl) eqnl in + GCases (sty,Option.map (f e') rtntypopt,tml',List.flatten eqnl') | NLetTuple (nal,(na,po),b,c) -> - let e',nal = List.fold_left_map g e nal in - let e'',na = g e na in + let e',nal = List.fold_left_map (protect g) e nal in + let e'',na = protect g e na in GLetTuple (nal,(na,Option.map (f e'') po),f e b,f e' c) | NIf (c,(na,po),b1,b2) -> - let e',na = g e na in + let e',na = protect g e na in GIf (f e c,(na,Option.map (f e') po),f e b1,f e b2) | NRec (fk,idl,dll,tl,bl) -> - let e,dll = Array.fold_left_map (List.fold_left_map (fun e (na,oc,b) -> - let e,na = g e na in + let e,dll = Array.fold_left_map (List.fold_map (fun e (na,oc,b) -> + let e,na = protect g e na in (e,(na,Explicit,Option.map (f e) oc,f e b)))) e dll in - let e',idl = Array.fold_left_map (to_id g) e idl in + let e',idl = Array.fold_left_map (to_id (protect g)) e idl in GRec (fk,idl,dll,Array.map (f e) tl,Array.map (f e') bl) | NCast (c,k) -> GCast (f e c,Miscops.map_cast_type (f e) k) | NSort x -> GSort x @@ -195,13 +221,19 @@ let glob_constr_of_notation_constr_with_binders ?loc g f e nc = let glob_constr_of_notation_constr ?loc x = let rec aux () x = - glob_constr_of_notation_constr_with_binders ?loc (fun () id -> ((),id)) aux () x + glob_constr_of_notation_constr_with_binders ?loc (fun () id -> ((),None,id)) aux () x in aux () x (******************************************************************************) (* Translating a glob_constr into a notation, interpreting recursive patterns *) -let add_id r id = r := (id :: pi1 !r, pi2 !r, pi3 !r) +type found_variables = { + vars : Id.t list; + recursive_term_vars : (Id.t * Id.t) list; + recursive_binders_vars : (Id.t * Id.t) list; + } + +let add_id r id = r := { !r with vars = id :: (!r).vars } let add_name r = function Anonymous -> () | Name id -> add_id r id let is_gvar id c = match DAst.get c with @@ -245,13 +277,25 @@ let check_is_hole id t = match DAst.get t with GHole _ -> () | _ -> (strbrk "In recursive notation with binders, " ++ Id.print id ++ strbrk " is expected to come without type.") +let check_pair_matching ?loc x y x' y' revert revert' = + if not (Id.equal x x' && Id.equal y y' && revert = revert') then + let x,y = if revert then y,x else x,y in + let x',y' = if revert' then y',x' else x',y' in + (* This is a case where one would like to highlight two locations! *) + user_err ?loc + (strbrk "Found " ++ Id.print x ++ strbrk " matching " ++ Id.print y ++ + strbrk " while " ++ Id.print x' ++ strbrk " matching " ++ Id.print y' ++ + strbrk " was first found.") + let pair_equal eq1 eq2 (a,b) (a',b') = eq1 a a' && eq2 b b' +let mem_recursive_pair (x,y) l = List.mem_f (pair_equal Id.equal Id.equal) (x,y) l + type recursive_pattern_kind = -| RecursiveTerms of bool (* associativity *) -| RecursiveBinders of glob_constr * glob_constr +| RecursiveTerms of bool (* in reverse order *) +| RecursiveBinders of bool (* in reverse order *) -let compare_recursive_parts found f f' (iterator,subc) = +let compare_recursive_parts recvars found f f' (iterator,subc) = let diff = ref None in let terminator = ref None in let rec aux c1 c2 = match DAst.get c1, DAst.get c2 with @@ -270,24 +314,41 @@ let compare_recursive_parts found f f' (iterator,subc) = List.for_all2eq aux l1 l2 | _ -> mk_glob_constr_eq aux c1 c2 end - | GVar x, GVar y when not (Id.equal x y) -> + | GVar x, GVar y + when mem_recursive_pair (x,y) recvars || mem_recursive_pair (y,x) recvars -> (* We found the position where it differs *) - let lassoc = match !terminator with None -> false | Some _ -> true in - let x,y = if lassoc then y,x else x,y in + let revert = mem_recursive_pair (y,x) recvars in + let x,y = if revert then y,x else x,y in begin match !diff with | None -> - let () = diff := Some (x, y, RecursiveTerms lassoc) in + let () = diff := Some (x, y, RecursiveTerms revert) in + true + | Some (x', y', RecursiveTerms revert') + | Some (x', y', RecursiveBinders revert') -> + check_pair_matching ?loc:c1.CAst.loc x y x' y' revert revert'; true - | Some _ -> false end | GLambda (Name x,_,t_x,c), GLambda (Name y,_,t_y,term) - | GProd (Name x,_,t_x,c), GProd (Name y,_,t_y,term) when not (Id.equal x y) -> + | GProd (Name x,_,t_x,c), GProd (Name y,_,t_y,term) + when mem_recursive_pair (x,y) recvars || mem_recursive_pair (y,x) recvars -> (* We found a binding position where it differs *) + check_is_hole x t_x; + check_is_hole y t_y; + let revert = mem_recursive_pair (y,x) recvars in + let x,y = if revert then y,x else x,y in begin match !diff with | None -> - let () = diff := Some (x, y, RecursiveBinders (t_x,t_y)) in + let () = diff := Some (x, y, RecursiveBinders revert) in aux c term - | Some _ -> false + | Some (x', y', RecursiveBinders revert') -> + check_pair_matching ?loc:c1.CAst.loc x y x' y' revert revert'; + true + | Some (x', y', RecursiveTerms revert') -> + (* Recursive binders have precedence: they can be coerced to + terms but not reciprocally *) + check_pair_matching ?loc:c1.CAst.loc x y x' y' revert revert'; + let () = diff := Some (x, y, RecursiveBinders revert) in + true end | _ -> mk_glob_constr_eq aux c1 c2 in @@ -296,46 +357,36 @@ let compare_recursive_parts found f f' (iterator,subc) = | None -> let loc1 = loc_of_glob_constr iterator in let loc2 = loc_of_glob_constr (Option.get !terminator) in - (* Here, we would need a loc made of several parts ... *) - user_err ?loc:(subtract_loc loc1 loc2) + (* Here, we would need a loc made of several parts ... *) + user_err ?loc:(subtract_loc loc1 loc2) (str "Both ends of the recursive pattern are the same.") - | Some (x,y,RecursiveTerms lassoc) -> - let toadd,x,y,lassoc = - if List.mem_f (pair_equal Id.equal Id.equal) (x,y) (pi2 !found) || - List.mem_f (pair_equal Id.equal Id.equal) (x,y) (pi3 !found) - then - None,x,y,lassoc - else if List.mem_f (pair_equal Id.equal Id.equal) (y,x) (pi2 !found) || - List.mem_f (pair_equal Id.equal Id.equal) (y,x) (pi3 !found) - then - None,y,x,not lassoc - else - Some (x,y),x,y,lassoc in - let iterator = - f' (if lassoc then iterator - else subst_glob_vars [x, DAst.make @@ GVar y] iterator) in - (* found variables have been collected by compare_constr *) - found := (List.remove Id.equal y (pi1 !found), - Option.fold_right (fun a l -> a::l) toadd (pi2 !found), - pi3 !found); - NList (x,y,iterator,f (Option.get !terminator),lassoc) - | Some (x,y,RecursiveBinders (t_x,t_y)) -> - let iterator = f' (subst_glob_vars [x, DAst.make @@ GVar y] iterator) in - (* found have been collected by compare_constr *) - found := (List.remove Id.equal y (pi1 !found), pi2 !found, (x,y) :: pi3 !found); - check_is_hole x t_x; - check_is_hole y t_y; - NBinderList (x,y,iterator,f (Option.get !terminator)) + | Some (x,y,RecursiveTerms revert) -> + (* By arbitrary convention, we use the second variable of the pair + as the place-holder for the iterator *) + let iterator = + f' (if revert then iterator else subst_glob_vars [x, DAst.make @@ GVar y] iterator) in + (* found variables have been collected by compare_constr *) + found := { !found with vars = List.remove Id.equal y (!found).vars; + recursive_term_vars = List.add_set (pair_equal Id.equal Id.equal) (x,y) (!found).recursive_term_vars }; + NList (x,y,iterator,f (Option.get !terminator),revert) + | Some (x,y,RecursiveBinders revert) -> + let iterator = + f' (if revert then iterator else subst_glob_vars [x, DAst.make @@ GVar y] iterator) in + (* found have been collected by compare_constr *) + found := { !found with vars = List.remove Id.equal y (!found).vars; + recursive_binders_vars = List.add_set (pair_equal Id.equal Id.equal) (x,y) (!found).recursive_binders_vars }; + NBinderList (x,y,iterator,f (Option.get !terminator),revert) else raise Not_found -let notation_constr_and_vars_of_glob_constr a = - let found = ref ([],[],[]) in +let notation_constr_and_vars_of_glob_constr recvars a = + let found = ref { vars = []; recursive_term_vars = []; recursive_binders_vars = [] } in let has_ltac = ref false in + (* Turn a glob_constr into a notation_constr by first trying to find a recursive pattern *) let rec aux c = let keepfound = !found in (* n^2 complexity but small and done only once per notation *) - try compare_recursive_parts found aux aux' (split_at_recursive_part c) + try compare_recursive_parts recvars found aux aux' (split_at_recursive_part c) with Not_found -> found := keepfound; match DAst.get c with @@ -395,8 +446,9 @@ let notation_constr_and_vars_of_glob_constr a = (* Side effect *) t, !found, !has_ltac -let check_variables_and_reversibility nenv (found,foundrec,foundrecbinding) = - let injective = ref true in +let check_variables_and_reversibility nenv + { vars = found; recursive_term_vars = foundrec; recursive_binders_vars = foundrecbinding } = + let injective = ref [] in let recvars = nenv.ninterp_rec_vars in let fold _ y accu = Id.Set.add y accu in let useless_vars = Id.Map.fold fold recvars Id.Set.empty in @@ -419,33 +471,36 @@ let check_variables_and_reversibility nenv (found,foundrec,foundrecbinding) = user_err Pp.(str (Id.to_string x ^ " should not be bound in a recursive pattern of the right-hand side.")) - else injective := false + else injective := x :: !injective in let check_pair s x y where = - if not (List.mem_f (pair_equal Id.equal Id.equal) (x,y) where) then + if not (mem_recursive_pair (x,y) where) then user_err (strbrk "in the right-hand side, " ++ Id.print x ++ str " and " ++ Id.print y ++ strbrk " should appear in " ++ str s ++ str " position as part of a recursive pattern.") in let check_type x typ = match typ with - | NtnInternTypeConstr -> + | NtnInternTypeAny -> begin try check_pair "term" x (Id.Map.find x recvars) foundrec with Not_found -> check_bound x end - | NtnInternTypeBinder -> + | NtnInternTypeOnlyBinder -> begin try check_pair "binding" x (Id.Map.find x recvars) foundrecbinding with Not_found -> check_bound x - end - | NtnInternTypeIdent -> check_bound x in + end in Id.Map.iter check_type vars; - !injective + List.rev !injective let notation_constr_of_glob_constr nenv a = - let a, found, has_ltac = notation_constr_and_vars_of_glob_constr a in + let recvars = Id.Map.bindings nenv.ninterp_rec_vars in + let a, found, has_ltac = notation_constr_and_vars_of_glob_constr recvars a in let injective = check_variables_and_reversibility nenv found in - a, not has_ltac && injective + let status = if has_ltac then HasLtac else match injective with + | [] -> APrioriReversible + | l -> NonInjective l in + a, status (**********************************************************************) (* Substitution of kernel names, avoiding a list of bound identifiers *) @@ -501,11 +556,11 @@ let rec subst_notation_constr subst bound raw = if r1' == r1 && r2' == r2 then raw else NProd (n,r1',r2') - | NBinderList (id1,id2,r1,r2) -> + | NBinderList (id1,id2,r1,r2,b) -> let r1' = subst_notation_constr subst bound r1 and r2' = subst_notation_constr subst bound r2 in if r1' == r1 && r2' == r2 then raw else - NBinderList (id1,id2,r1',r2') + NBinderList (id1,id2,r1',r2',b) | NLetIn (n,r1,t,r2) -> let r1' = subst_notation_constr subst bound r1 in @@ -616,8 +671,20 @@ let is_term_meta id metas = try match Id.List.assoc id metas with _,(NtnTypeConstr | NtnTypeConstrList) -> true | _ -> false with Not_found -> false +let is_onlybinding_strict_meta id metas = + try match Id.List.assoc id metas with _,NtnTypeBinder (NtnParsedAsPattern true) -> true | _ -> false + with Not_found -> false + let is_onlybinding_meta id metas = - try match Id.List.assoc id metas with _,NtnTypeOnlyBinder -> true | _ -> false + try match Id.List.assoc id metas with _,NtnTypeBinder _ -> true | _ -> false + with Not_found -> false + +let is_onlybinding_pattern_like_meta isvar id metas = + try match Id.List.assoc id metas with + | _,NtnTypeBinder (NtnBinderParsedAsConstr + (Extend.AsIdentOrPattern | Extend.AsStrictPattern)) -> true + | _,NtnTypeBinder (NtnParsedAsPattern strict) -> not (strict && isvar) + | _ -> false with Not_found -> false let is_bindinglist_meta id metas = @@ -636,7 +703,7 @@ let alpha_rename alpmetas v = if alpmetas == [] then v else try rename_glob_vars alpmetas v with UnsoundRenaming -> raise No_match -let add_env (alp,alpmetas) (terms,onlybinders,termlists,binderlists) var v = +let add_env (alp,alpmetas) (terms,termlists,binders,binderlists) var v = (* Check that no capture of binding variables occur *) (* [alp] is used when matching a pattern "fun x => ... x ... ?var ... x ..." with an actual term "fun z => ... z ..." when "x" is not bound in the @@ -664,19 +731,49 @@ let add_env (alp,alpmetas) (terms,onlybinders,termlists,binderlists) var v = refinement *) let v = alpha_rename alpmetas v in (* TODO: handle the case of multiple occs in different scopes *) - ((var,v)::terms,onlybinders,termlists,binderlists) + ((var,v)::terms,termlists,binders,binderlists) -let add_termlist_env (alp,alpmetas) (terms,onlybinders,termlists,binderlists) var vl = +let add_termlist_env (alp,alpmetas) (terms,termlists,binders,binderlists) var vl = if List.exists (fun (id,_) -> List.exists (occur_glob_constr id) vl) alp then raise No_match; let vl = List.map (alpha_rename alpmetas) vl in - (terms,onlybinders,(var,vl)::termlists,binderlists) + (terms,(var,vl)::termlists,binders,binderlists) -let add_binding_env alp (terms,onlybinders,termlists,binderlists) var v = +let add_binding_env alp (terms,termlists,binders,binderlists) var v = (* TODO: handle the case of multiple occs in different scopes *) - (terms,(var,v)::onlybinders,termlists,binderlists) + (terms,termlists,(var,v)::binders,binderlists) -let add_bindinglist_env (terms,onlybinders,termlists,binderlists) x bl = - (terms,onlybinders,termlists,(x,bl)::binderlists) +let add_bindinglist_env (terms,termlists,binders,binderlists) x bl = + (terms,termlists,binders,(x,bl)::binderlists) + +let rec map_cases_pattern_name_left f = DAst.map (function + | PatVar na -> PatVar (f na) + | PatCstr (c,l,na) -> PatCstr (c,List.map_left (map_cases_pattern_name_left f) l,f na) + ) + +let rec fold_cases_pattern_eq f x p p' = + let loc = p.CAst.loc in + match DAst.get p, DAst.get p' with + | PatVar na, PatVar na' -> let x,na = f x na na' in x, DAst.make ?loc @@ PatVar na + | PatCstr (c,l,na), PatCstr (c',l',na') when eq_constructor c c' -> + let x,l = fold_cases_pattern_list_eq f x l l' in + let x,na = f x na na' in + x, DAst.make ?loc @@ PatCstr (c,l,na) + | _ -> failwith "Not equal" + +and fold_cases_pattern_list_eq f x pl pl' = match pl, pl' with + | [], [] -> x, [] + | p::pl, p'::pl' -> + let x, p = fold_cases_pattern_eq f x p p' in + let x, pl = fold_cases_pattern_list_eq f x pl pl' in + x, p :: pl + | _ -> assert false + +let rec cases_pattern_eq p1 p2 = match DAst.get p1, DAst.get p2 with +| PatVar na1, PatVar na2 -> Name.equal na1 na2 +| PatCstr (c1, pl1, na1), PatCstr (c2, pl2, na2) -> + eq_constructor c1 c2 && List.equal cases_pattern_eq pl1 pl2 && + Name.equal na1 na2 +| _ -> false let rec pat_binder_of_term t = DAst.map (function | GVar id -> PatVar (Name id) @@ -691,39 +788,111 @@ let rec pat_binder_of_term t = DAst.map (function | _ -> raise No_match ) t -let bind_term_env alp (terms,onlybinders,termlists,binderlists as sigma) var v = +let unify_name_upto alp na na' = + match na, na' with + | Anonymous, na' -> alp, na' + | na, Anonymous -> alp, na + | Name id, Name id' -> + if Id.equal id id' then alp, na' + else (fst alp,(id,id')::snd alp), na' + +let unify_pat_upto alp p p' = + try fold_cases_pattern_eq unify_name_upto alp p p' with Failure _ -> raise No_match + +let unify_term alp v v' = + match DAst.get v, DAst.get v' with + | GHole _, _ -> v' + | _, GHole _ -> v + | _, _ -> if glob_constr_eq (alpha_rename (snd alp) v) v' then v else raise No_match + +let unify_opt_term alp v v' = + match v, v' with + | Some t, Some t' -> Some (unify_term alp t t') + | (Some _ as x), None | None, (Some _ as x) -> x + | None, None -> None + +let unify_binding_kind bk bk' = if bk == bk' then bk' else raise No_match + +let unify_binder_upto alp b b' = + let loc, loc' = CAst.(b.loc, b'.loc) in + match DAst.get b, DAst.get b' with + | GLocalAssum (na,bk,t), GLocalAssum (na',bk',t') -> + let alp, na = unify_name_upto alp na na' in + alp, DAst.make ?loc @@ GLocalAssum (na, unify_binding_kind bk bk', unify_term alp t t') + | GLocalDef (na,bk,c,t), GLocalDef (na',bk',c',t') -> + let alp, na = unify_name_upto alp na na' in + alp, DAst.make ?loc @@ GLocalDef (na, unify_binding_kind bk bk', unify_term alp c c', unify_opt_term alp t t') + | GLocalPattern ((disjpat,ids),id,bk,t), GLocalPattern ((disjpat',_),_,bk',t') when List.length disjpat = List.length disjpat' -> + let alp, p = List.fold_left2_map unify_pat_upto alp disjpat disjpat' in + alp, DAst.make ?loc @@ GLocalPattern ((p,ids), id, unify_binding_kind bk bk', unify_term alp t t') + | _ -> raise No_match + +let rec unify_terms alp vl vl' = + match vl, vl' with + | [], [] -> [] + | v :: vl, v' :: vl' -> unify_term alp v v' :: unify_terms alp vl vl' + | _ -> raise No_match + +let rec unify_binders_upto alp bl bl' = + match bl, bl' with + | [], [] -> alp, [] + | b :: bl, b' :: bl' -> + let alp,b = unify_binder_upto alp b b' in + let alp,bl = unify_binders_upto alp bl bl' in + alp, b :: bl + | _ -> raise No_match + +let unify_id alp id na' = + match na' with + | Anonymous -> Name (rename_var (snd alp) id) + | Name id' -> + if Id.equal (rename_var (snd alp) id) id' then na' else raise No_match + +let unify_pat alp p p' = + if cases_pattern_eq (map_cases_pattern_name_left (Name.map (rename_var (snd alp))) p) p' then p' + else raise No_match + +let unify_term_binder alp c = DAst.(map (fun b' -> + match DAst.get c, b' with + | GVar id, GLocalAssum (na', bk', t') -> + GLocalAssum (unify_id alp id na', bk', t') + | _, GLocalPattern (([p'],ids), id, bk', t') -> + let p = pat_binder_of_term c in + GLocalPattern (([unify_pat alp p p'],ids), id, bk', t') + | _ -> raise No_match)) + +let rec unify_terms_binders alp cl bl' = + match cl, bl' with + | [], [] -> [] + | c :: cl, b' :: bl' -> + begin match DAst.get b' with + | GLocalDef ( _, _, _, t) -> unify_terms_binders alp cl bl' + | _ -> unify_term_binder alp c b' :: unify_terms_binders alp cl bl' + end + | _ -> raise No_match + +let bind_term_env alp (terms,termlists,binders,binderlists as sigma) var v = try + (* If already bound to a term, unify with the new term *) let v' = Id.List.assoc var terms in - match DAst.get v, DAst.get v' with - | GHole _, _ -> sigma - | _, GHole _ -> - let sigma = Id.List.remove_assoc var terms,onlybinders,termlists,binderlists in - add_env alp sigma var v - | _, _ -> - if glob_constr_eq (alpha_rename (snd alp) v) v' then sigma - else raise No_match + let v'' = unify_term alp v v' in + if v'' == v' then sigma else + let sigma = (Id.List.remove_assoc var terms,termlists,binders,binderlists) in + add_env alp sigma var v with Not_found -> add_env alp sigma var v -let bind_termlist_env alp (terms,onlybinders,termlists,binderlists as sigma) var vl = +let bind_termlist_env alp (terms,termlists,binders,binderlists as sigma) var vl = try + (* If already bound to a list of term, unify with the new terms *) let vl' = Id.List.assoc var termlists in - let unify_term v v' = - match DAst.get v, DAst.get v' with - | GHole _, _ -> v' - | _, GHole _ -> v - | _, _ -> if glob_constr_eq (alpha_rename (snd alp) v) v' then v' else raise No_match in - let rec unify vl vl' = - match vl, vl' with - | [], [] -> [] - | v :: vl, v' :: vl' -> unify_term v v' :: unify vl vl' - | _ -> raise No_match in - let vl = unify vl vl' in - let sigma = (terms,onlybinders,Id.List.remove_assoc var termlists,binderlists) in + let vl = unify_terms alp vl vl' in + let sigma = (terms,Id.List.remove_assoc var termlists,binders,binderlists) in add_termlist_env alp sigma var vl with Not_found -> add_termlist_env alp sigma var vl -let bind_term_as_binding_env alp (terms,onlybinders,termlists,binderlists as sigma) var id = +let bind_term_as_binding_env alp (terms,termlists,binders,binderlists as sigma) var id = try + (* If already bound to a term, unify the binder and the term *) match DAst.get (Id.List.assoc var terms) with | GVar id' -> (if not (Id.equal id id') then (fst alp,(id,id')::snd alp) else alp), @@ -735,142 +904,49 @@ let bind_term_as_binding_env alp (terms,onlybinders,termlists,binderlists as sig (* TODO: look at the consequences for alp *) alp, add_env alp sigma var (DAst.make @@ GVar id) -let bind_binding_as_term_env alp (terms,onlybinders,termlists,binderlists as sigma) var id = +let force_cases_pattern c = + DAst.make ?loc:c.CAst.loc (DAst.get c) + +let bind_binding_as_term_env alp (terms,termlists,binders,binderlists as sigma) var c = + let pat = try force_cases_pattern (cases_pattern_of_glob_constr Anonymous c) with Not_found -> raise No_match in try - let v' = Id.List.assoc var onlybinders in - match v' with - | Anonymous -> - (* Should not occur, since the term has to be bound upwards *) - let sigma = (terms,Id.List.remove_assoc var onlybinders,termlists,binderlists) in - add_binding_env alp sigma var (Name id) - | Name id' -> - if Id.equal (rename_var (snd alp) id) id' then sigma else raise No_match - with Not_found -> add_binding_env alp sigma var (Name id) - -let bind_binding_env alp (terms,onlybinders,termlists,binderlists as sigma) var v = + (* If already bound to a binder, unify the term and the binder *) + let patl' = Id.List.assoc var binders in + let patl'' = List.map2 (unify_pat alp) [pat] patl' in + if patl' == patl'' then sigma + else + let sigma = (terms,termlists,Id.List.remove_assoc var binders,binderlists) in + add_binding_env alp sigma var patl'' + with Not_found -> add_binding_env alp sigma var [pat] + +let bind_binding_env alp (terms,termlists,binders,binderlists as sigma) var disjpat = try - let v' = Id.List.assoc var onlybinders in - match v, v' with - | Anonymous, _ -> alp, sigma - | _, Anonymous -> - let sigma = (terms,Id.List.remove_assoc var onlybinders,termlists,binderlists) in - alp, add_binding_env alp sigma var v - | Name id1, Name id2 -> - if Id.equal id1 id2 then alp,sigma - else (fst alp,(id1,id2)::snd alp),sigma - with Not_found -> alp, add_binding_env alp sigma var v - -let rec map_cases_pattern_name_left f = DAst.map (function - | PatVar na -> PatVar (f na) - | PatCstr (c,l,na) -> PatCstr (c,List.map_left (map_cases_pattern_name_left f) l,f na) - ) - -let rec fold_cases_pattern_eq f x p p' = - let loc = p.CAst.loc in - match DAst.get p, DAst.get p' with - | PatVar na, PatVar na' -> let x,na = f x na na' in x, DAst.make ?loc @@ PatVar na - | PatCstr (c,l,na), PatCstr (c',l',na') when eq_constructor c c' -> - let x,l = fold_cases_pattern_list_eq f x l l' in - let x,na = f x na na' in - x, DAst.make ?loc @@ PatCstr (c,l,na) - | _ -> failwith "Not equal" - -and fold_cases_pattern_list_eq f x pl pl' = match pl, pl' with - | [], [] -> x, [] - | p::pl, p'::pl' -> - let x, p = fold_cases_pattern_eq f x p p' in - let x, pl = fold_cases_pattern_list_eq f x pl pl' in - x, p :: pl - | _ -> assert false - -let rec cases_pattern_eq p1 p2 = match DAst.get p1, DAst.get p2 with -| PatVar na1, PatVar na2 -> Name.equal na1 na2 -| PatCstr (c1, pl1, na1), PatCstr (c2, pl2, na2) -> - eq_constructor c1 c2 && List.equal cases_pattern_eq pl1 pl2 && - Name.equal na1 na2 -| _ -> false - -let bind_bindinglist_env alp (terms,onlybinders,termlists,binderlists as sigma) var bl = + (* If already bound to a binder possibly *) + (* generating an alpha-renaming from unifying the new binder *) + let disjpat' = Id.List.assoc var binders in + let alp, disjpat = List.fold_left2_map unify_pat_upto alp disjpat disjpat' in + let sigma = (terms,termlists,Id.List.remove_assoc var binders,binderlists) in + alp, add_binding_env alp sigma var disjpat + with Not_found -> alp, add_binding_env alp sigma var disjpat + +let bind_bindinglist_env alp (terms,termlists,binders,binderlists as sigma) var bl = let bl = List.rev bl in try + (* If already bound to a list of binders possibly *) + (* generating an alpha-renaming from unifying the new binders *) let bl' = Id.List.assoc var binderlists in - let unify_name alp na na' = - match na, na' with - | Anonymous, na' -> alp, na' - | na, Anonymous -> alp, na - | Name id, Name id' -> - if Id.equal id id' then alp, na' - else (fst alp,(id,id')::snd alp), na' in - let unify_pat alp p p' = - try fold_cases_pattern_eq unify_name alp p p' with Failure _ -> raise No_match in - let unify_term alp v v' = - match DAst.get v, DAst.get v' with - | GHole _, _ -> v' - | _, GHole _ -> v - | _, _ -> if glob_constr_eq (alpha_rename (snd alp) v) v' then v else raise No_match in - let unify_opt_term alp v v' = - match v, v' with - | Some t, Some t' -> Some (unify_term alp t t') - | (Some _ as x), None | None, (Some _ as x) -> x - | None, None -> None in - let unify_binding_kind bk bk' = if bk == bk' then bk' else raise No_match in - let unify_binder alp b b' = - let loc, loc' = CAst.(b.loc, b'.loc) in - match DAst.get b, DAst.get b' with - | GLocalAssum (na,bk,t), GLocalAssum (na',bk',t') -> - let alp, na = unify_name alp na na' in - alp, DAst.make ?loc @@ GLocalAssum (na, unify_binding_kind bk bk', unify_term alp t t') - | GLocalDef (na,bk,c,t), GLocalDef (na',bk',c',t') -> - let alp, na = unify_name alp na na' in - alp, DAst.make ?loc @@ GLocalDef (na, unify_binding_kind bk bk', unify_term alp c c', unify_opt_term alp t t') - | GLocalPattern ((p,ids),id,bk,t), GLocalPattern ((p',_),_,bk',t') -> - let alp, p = unify_pat alp p p' in - alp, DAst.make ?loc @@ GLocalPattern ((p,ids), id, unify_binding_kind bk bk', unify_term alp t t') - | _ -> raise No_match in - let rec unify alp bl bl' = - match bl, bl' with - | [], [] -> alp, [] - | b :: bl, b' :: bl' -> - let alp,b = unify_binder alp b b' in - let alp,bl = unify alp bl bl' in - alp, b :: bl - | _ -> raise No_match in - let alp, bl = unify alp bl bl' in - let sigma = (terms,onlybinders,termlists,Id.List.remove_assoc var binderlists) in + let alp, bl = unify_binders_upto alp bl bl' in + let sigma = (terms,termlists,binders,Id.List.remove_assoc var binderlists) in alp, add_bindinglist_env sigma var bl with Not_found -> alp, add_bindinglist_env sigma var bl -let bind_bindinglist_as_term_env alp (terms,onlybinders,termlists,binderlists) var cl = +let bind_bindinglist_as_termlist_env alp (terms,termlists,binders,binderlists) var cl = try + (* If already bound to a list of binders, unify the terms and binders *) let bl' = Id.List.assoc var binderlists in - let unify_id id na' = - match na' with - | Anonymous -> Name (rename_var (snd alp) id) - | Name id' -> - if Id.equal (rename_var (snd alp) id) id' then na' else raise No_match in - let unify_pat p p' = - if cases_pattern_eq (map_cases_pattern_name_left (Name.map (rename_var (snd alp))) p) p' then p' - else raise No_match in - let unify_term_binder c = DAst.(map (fun b' -> - match DAst.get c, b' with - | GVar id, GLocalAssum (na', bk', t') -> - GLocalAssum (unify_id id na', bk', t') - | _, GLocalPattern ((p',ids), id, bk', t') -> - let p = pat_binder_of_term c in - GLocalPattern ((unify_pat p p',ids), id, bk', t') - | _ -> raise No_match )) in - let rec unify cl bl' = - match cl, bl' with - | [], [] -> [] - | c :: cl, b' :: bl' -> - begin match DAst.get b' with - | GLocalDef ( _, _, _, t) -> unify cl bl' - | _ -> unify_term_binder c b' :: unify cl bl' - end - | _ -> raise No_match in - let bl = unify cl bl' in - let sigma = (terms,onlybinders,termlists,Id.List.remove_assoc var binderlists) in + let bl = unify_terms_binders alp cl bl' in + let sigma = (terms,termlists,binders,Id.List.remove_assoc var binderlists) in add_bindinglist_env sigma var bl with Not_found -> anomaly (str "There should be a binder list bindings this list of terms.") @@ -894,8 +970,10 @@ let match_opt f sigma t1 t2 = match (t1,t2) with | _ -> raise No_match let match_names metas (alp,sigma) na1 na2 = match (na1,na2) with + | (na1,Name id2) when is_onlybinding_strict_meta id2 metas -> + raise No_match | (na1,Name id2) when is_onlybinding_meta id2 metas -> - bind_binding_env alp sigma id2 na1 + bind_binding_env alp sigma id2 [DAst.make (PatVar na1)] | (Name id1,Name id2) when is_term_meta id2 metas -> (* We let the non-binding occurrence define the rhs and hence reason up to *) (* alpha-conversion for the given occurrence of the name (see #4592)) *) @@ -907,54 +985,42 @@ let match_names metas (alp,sigma) na1 na2 = match (na1,na2) with | (Anonymous,Anonymous) -> alp,sigma | _ -> raise No_match -let rec match_cases_pattern_binders metas acc pat1 pat2 = +let rec match_cases_pattern_binders allow_catchall metas (alp,sigma as acc) pat1 pat2 = match DAst.get pat1, DAst.get pat2 with + | PatVar _, PatVar (Name id2) when is_onlybinding_pattern_like_meta true id2 metas -> + bind_binding_env alp sigma id2 [pat1] + | _, PatVar (Name id2) when is_onlybinding_pattern_like_meta false id2 metas -> + bind_binding_env alp sigma id2 [pat1] | PatVar na1, PatVar na2 -> match_names metas acc na1 na2 + | _, PatVar Anonymous when allow_catchall -> acc | PatCstr (c1,patl1,na1), PatCstr (c2,patl2,na2) when eq_constructor c1 c2 && Int.equal (List.length patl1) (List.length patl2) -> - List.fold_left2 (match_cases_pattern_binders metas) - (match_names metas acc na1 na2) patl1 patl2 + List.fold_left2 (match_cases_pattern_binders false metas) + (match_names metas acc na1 na2) patl1 patl2 | _ -> raise No_match -let glue_letin_with_decls = true - -let rec match_iterated_binders islambda decls bi = DAst.(with_loc_val (fun ?loc -> function - | GLambda (na,bk,t,b) as b0 -> - begin match na, DAst.get b with - | Name p, GCases (LetPatternStyle,None,[(e,_)],[(_,(ids,[cp],b))]) - when islambda && is_gvar p e && not (occur_glob_constr p b) -> - match_iterated_binders islambda ((DAst.make ?loc @@ GLocalPattern((cp,ids),p,bk,t))::decls) b - | _, _ when islambda -> - match_iterated_binders islambda ((DAst.make ?loc @@ GLocalAssum(na,bk,t))::decls) b - | _ -> (decls, DAst.make ?loc b0) - end - | GProd (na,bk,t,b) as b0 -> - begin match na, DAst.get b with - | Name p, GCases (LetPatternStyle,None,[(e,_)],[(_,(ids,[cp],b))]) - when not islambda && is_gvar p e && not (occur_glob_constr p b) -> - match_iterated_binders islambda ((DAst.make ?loc @@ GLocalPattern((cp,ids),p,bk,t))::decls) b - | Name _, _ when not islambda -> - match_iterated_binders islambda ((DAst.make ?loc @@ GLocalAssum(na,bk,t))::decls) b - | _ -> (decls, DAst.make ?loc b0) - end - | GLetIn (na,c,t,b) when glue_letin_with_decls -> - match_iterated_binders islambda - ((DAst.make ?loc @@ GLocalDef (na,Explicit (*?*), c,t))::decls) b - | b -> (decls, DAst.make ?loc b) - )) bi - -let remove_sigma x (terms,onlybinders,termlists,binderlists) = - (Id.List.remove_assoc x terms,onlybinders,termlists,binderlists) +let remove_sigma x (terms,termlists,binders,binderlists) = + (Id.List.remove_assoc x terms,termlists,binders,binderlists) -let remove_bindinglist_sigma x (terms,onlybinders,termlists,binderlists) = - (terms,onlybinders,termlists,Id.List.remove_assoc x binderlists) +let remove_bindinglist_sigma x (terms,termlists,binders,binderlists) = + (terms,termlists,binders,Id.List.remove_assoc x binderlists) let add_ldots_var metas = (ldots_var,((None,[]),NtnTypeConstr))::metas let add_meta_bindinglist x metas = (x,((None,[]),NtnTypeBinderList))::metas -let match_binderlist_with_app match_fun alp metas sigma rest x y iter termin = - let rec aux sigma bl rest = +(* This tells if letins in the middle of binders should be included in + the sequence of binders *) +let glue_inner_letin_with_decls = true + +(* This tells if trailing letins (with no further proper binders) + should be included in sequence of binders *) +let glue_trailing_letin_with_decls = false + +exception OnlyTrailingLetIns + +let match_binderlist match_fun alp metas sigma rest x y iter termin revert = + let rec aux trailing_letins sigma bl rest = try let metas = add_ldots_var (add_meta_bindinglist y metas) in let (terms,_,_,binderlists as sigma) = match_fun alp metas sigma rest iter in @@ -963,16 +1029,32 @@ let match_binderlist_with_app match_fun alp metas sigma rest x y iter termin = match Id.List.assoc y binderlists with [b] -> b | _ ->assert false in let sigma = remove_bindinglist_sigma y (remove_sigma ldots_var sigma) in - aux sigma (b::bl) rest - with No_match when not (List.is_empty bl) -> - bl, rest, sigma in - let bl,rest,sigma = aux sigma [] rest in + (* In case y is bound not only to a binder but also to a term *) + let sigma = remove_sigma y sigma in + aux false sigma (b::bl) rest + with No_match -> + match DAst.get rest with + | GLetIn (na,c,t,rest') when glue_inner_letin_with_decls -> + let b = DAst.make ?loc:rest.CAst.loc @@ GLocalDef (na,Explicit (*?*), c,t) in + (* collect let-in *) + (try aux true sigma (b::bl) rest' + with OnlyTrailingLetIns + when not (trailing_letins && not glue_trailing_letin_with_decls) -> + (* renounce to take into account trailing let-ins *) + if not (List.is_empty bl) then bl, rest, sigma else raise No_match) + | _ -> + if trailing_letins && not glue_trailing_letin_with_decls then + (* Backtrack to when we tried to glue letins *) + raise OnlyTrailingLetIns; + if not (List.is_empty bl) then bl, rest, sigma else raise No_match in + let bl,rest,sigma = aux false sigma [] rest in + let bl = if revert then List.rev bl else bl in let alp,sigma = bind_bindinglist_env alp sigma x bl in match_fun alp metas sigma rest termin let add_meta_term x metas = (x,((None,[]),NtnTypeConstr))::metas -let match_termlist match_fun alp metas sigma rest x y iter termin lassoc = +let match_termlist match_fun alp metas sigma rest x y iter termin revert = let rec aux sigma acc rest = try let metas = add_ldots_var (add_meta_term y metas) in @@ -983,12 +1065,12 @@ let match_termlist match_fun alp metas sigma rest x y iter termin lassoc = aux sigma (t::acc) rest with No_match when not (List.is_empty acc) -> acc, match_fun metas sigma rest termin in - let l,(terms,onlybinders,termlists,binderlists as sigma) = aux sigma [] rest in - let l = if lassoc then l else List.rev l in + let l,(terms,termlists,binders,binderlists as sigma) = aux sigma [] rest in + let l = if revert then l else List.rev l in if is_bindinglist_meta x metas then (* This is a recursive pattern for both bindings and terms; it is *) (* registered for binders *) - bind_bindinglist_as_term_env alp sigma x l + bind_bindinglist_as_termlist_env alp sigma x l else bind_termlist_env alp sigma x l @@ -1023,72 +1105,19 @@ let rec match_ inner u alp metas sigma a1 a2 = match DAst.get a1, a2 with (* Matching notation variable *) | r1, NVar id2 when is_term_meta id2 metas -> bind_term_env alp sigma id2 a1 - | GVar id1, NVar id2 when is_onlybinding_meta id2 metas -> bind_binding_as_term_env alp sigma id2 id1 + | GVar _, NVar id2 when is_onlybinding_pattern_like_meta true id2 metas -> bind_binding_as_term_env alp sigma id2 a1 + | r1, NVar id2 when is_onlybinding_pattern_like_meta false id2 metas -> bind_binding_as_term_env alp sigma id2 a1 + | GVar _, NVar id2 when is_onlybinding_strict_meta id2 metas -> raise No_match + | GVar _, NVar id2 when is_onlybinding_meta id2 metas -> bind_binding_as_term_env alp sigma id2 a1 | r1, NVar id2 when is_bindinglist_meta id2 metas -> bind_term_env alp sigma id2 a1 (* Matching recursive notations for terms *) - | r1, NList (x,y,iter,termin,lassoc) -> - match_termlist (match_hd u alp) alp metas sigma a1 x y iter termin lassoc - - | GLambda (na1, bk, t1, b1), NBinderList (x,y,iter,termin) -> - begin match na1, DAst.get b1, iter with - (* "λ p, let 'cp = p in t" -> "λ 'cp, t" *) - | Name p, GCases (LetPatternStyle,None,[(e,_)],[(_,(ids,[cp],b1))]), NLambda (Name _, _, _) - when is_gvar p e && not (occur_glob_constr p b1) -> - let (decls,b) = match_iterated_binders true [DAst.make ?loc @@ GLocalPattern((cp,ids),p,bk,t1)] b1 in - let alp,sigma = bind_bindinglist_env alp sigma x decls in - match_in u alp metas sigma b termin - (* Matching recursive notations for binders: ad hoc cases supporting let-in *) - | _, _, NLambda (Name _,_,_) -> - let (decls,b) = match_iterated_binders true [DAst.make ?loc @@ GLocalAssum (na1,bk,t1)] b1 in - (* TODO: address the possibility that termin is a Lambda itself *) - let alp,sigma = bind_bindinglist_env alp sigma x decls in - match_in u alp metas sigma b termin - (* Matching recursive notations for binders: general case *) - | _, _, _ -> - match_binderlist_with_app (match_hd u) alp metas sigma a1 x y iter termin - end - - | GProd (na1, bk, t1, b1), NBinderList (x,y,iter,termin) -> - (* "∀ p, let 'cp = p in t" -> "∀ 'cp, t" *) - begin match na1, DAst.get b1, iter, termin with - | Name p, GCases (LetPatternStyle,None,[(e, _)],[(_,(ids,[cp],b1))]), NProd (Name _,_,_), NVar _ - when is_gvar p e && not (occur_glob_constr p b1) -> - let (decls,b) = match_iterated_binders true [DAst.make ?loc @@ GLocalPattern ((cp,ids),p,bk,t1)] b1 in - let alp,sigma = bind_bindinglist_env alp sigma x decls in - match_in u alp metas sigma b termin - | _, _, NProd (Name _,_,_), _ when na1 != Anonymous -> - let (decls,b) = match_iterated_binders false [DAst.make ?loc @@ GLocalAssum (na1,bk,t1)] b1 in - (* TODO: address the possibility that termin is a Prod itself *) - let alp,sigma = bind_bindinglist_env alp sigma x decls in - match_in u alp metas sigma b termin - (* Matching recursive notations for binders: general case *) - | _, _, _, _ -> - match_binderlist_with_app (match_hd u) alp metas sigma a1 x y iter termin - end + | r1, NList (x,y,iter,termin,revert) -> + match_termlist (match_hd u alp) alp metas sigma a1 x y iter termin revert (* Matching recursive notations for binders: general case *) - | _r, NBinderList (x,y,iter,termin) -> - match_binderlist_with_app (match_hd u) alp metas sigma a1 x y iter termin - - (* Matching individual binders as part of a recursive pattern *) - | GLambda (na1, bk, t1, b1), NLambda (na2, t2, b2) -> - begin match na1, DAst.get b1, na2 with - | Name p, GCases (LetPatternStyle,None,[(e,_)],[(_,(ids,[cp],b1))]), Name id - when is_gvar p e && is_bindinglist_meta id metas && not (occur_glob_constr p b1) -> - let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalPattern ((cp,ids),p,bk,t1)] in - match_in u alp metas sigma b1 b2 - | _, _, Name id when is_bindinglist_meta id metas -> - let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalAssum (na1,bk,t1)] in - match_in u alp metas sigma b1 b2 - | _ -> - match_binders u alp metas na1 na2 (match_in u alp metas sigma t1 t2) b1 b2 - end - - | GProd (na,bk,t,b1), NProd (Name id,_,b2) - when is_bindinglist_meta id metas && na != Anonymous -> - let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalAssum (na,bk,t)] in - match_in u alp metas sigma b1 b2 + | _r, NBinderList (x,y,iter,termin,revert) -> + match_binderlist (match_hd u) alp metas sigma a1 x y iter termin revert (* Matching compositionally *) | GVar id1, NVar id2 when alpha_var id1 id2 (fst alp) -> sigma @@ -1104,8 +1133,10 @@ let rec match_ inner u alp metas sigma a1 a2 = let may_use_eta = does_not_come_from_already_eta_expanded_var f1 in List.fold_left2 (match_ may_use_eta u alp metas) (match_in u alp metas sigma f1 f2) l1 l2 - | GProd (na1,_,t1,b1), NProd (na2,t2,b2) -> - match_binders u alp metas na1 na2 (match_in u alp metas sigma t1 t2) b1 b2 + | GLambda (na1,bk1,t1,b1), NLambda (na2,t2,b2) -> + match_extended_binders false u alp metas na1 na2 bk1 t1 (match_in u alp metas sigma t1 t2) b1 b2 + | GProd (na1,bk1,t1,b1), NProd (na2,t2,b2) -> + match_extended_binders true u alp metas na1 na2 bk1 t1 (match_in u alp metas sigma t1 t2) b1 b2 | GLetIn (na1,b1,_,c1), NLetIn (na2,b2,None,c2) | GLetIn (na1,b1,None,c1), NLetIn (na2,b2,_,c2) -> match_binders u alp metas na1 na2 (match_in u alp metas sigma b1 b2) c1 c2 @@ -1113,9 +1144,7 @@ let rec match_ inner u alp metas sigma a1 a2 = match_binders u alp metas na1 na2 (match_in u alp metas (match_in u alp metas sigma b1 b2) t1 t2) c1 c2 | GCases (sty1,rtno1,tml1,eqnl1), NCases (sty2,rtno2,tml2,eqnl2) - when sty1 == sty2 - && Int.equal (List.length tml1) (List.length tml2) - && Int.equal (List.length eqnl1) (List.length eqnl2) -> + when sty1 == sty2 && Int.equal (List.length tml1) (List.length tml2) -> let rtno1' = abstract_return_type_context_glob_constr tml1 rtno1 in let rtno2' = abstract_return_type_context_notation_constr tml2 rtno2 in let sigma = @@ -1125,7 +1154,14 @@ let rec match_ inner u alp metas sigma a1 a2 = let sigma = List.fold_left2 (fun s (tm1,_) (tm2,_) -> match_in u alp metas s tm1 tm2) sigma tml1 tml2 in - List.fold_left2 (match_equations u alp metas) sigma eqnl1 eqnl2 + (* Try two different strategies for matching clauses *) + (try + List.fold_left2_set No_match (match_equations u alp metas) sigma eqnl1 eqnl2 + with + No_match -> + List.fold_left2_set No_match (match_disjunctive_equations u alp metas) sigma + (Detyping.factorize_eqns eqnl1) + (List.map (fun (patl,rhs) -> ([patl],rhs)) eqnl2)) | GLetTuple (nal1,(na1,to1),b1,c1), NLetTuple (nal2,(na2,to2),b2,c2) when Int.equal (List.length nal1) (List.length nal2) -> let sigma = match_opt (match_binders u alp metas na1 na2) sigma to1 to2 in @@ -1191,44 +1227,83 @@ and match_in u = match_ true u and match_hd u = match_ false u and match_binders u alp metas na1 na2 sigma b1 b2 = + (* Match binders which cannot be substituted by a pattern *) let (alp,sigma) = match_names metas (alp,sigma) na1 na2 in match_in u alp metas sigma b1 b2 -and match_equations u alp metas sigma (_,(_,patl1,rhs1)) (patl2,rhs2) = +and match_extended_binders ?loc isprod u alp metas na1 na2 bk t sigma b1 b2 = + (* Match binders which can be substituted by a pattern *) + let store, get = set_temporary_memory () in + match na1, DAst.get b1, na2 with + (* Matching individual binders as part of a recursive pattern *) + | Name p, GCases (LetPatternStyle,None,[(e,_)],(_::_ as eqns)), Name id + when is_gvar p e && is_bindinglist_meta id metas && List.length (store (Detyping.factorize_eqns eqns)) = 1 -> + (match get () with + | [(_,(ids,disj_of_patl,b1))] -> + let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in + let disjpat = if occur_glob_constr p b1 then List.map (set_pat_alias p) disjpat else disjpat in + let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalPattern ((disjpat,ids),p,bk,t)] in + match_in u alp metas sigma b1 b2 + | _ -> assert false) + | Name p, GCases (LetPatternStyle,None,[(e,_)],(_::_ as eqns)), Name id + when is_gvar p e && is_onlybinding_pattern_like_meta false id metas && List.length (store (Detyping.factorize_eqns eqns)) = 1 -> + (match get () with + | [(_,(ids,disj_of_patl,b1))] -> + let disjpat = List.map (function [pat] -> pat | _ -> assert false) disj_of_patl in + let disjpat = if occur_glob_constr p b1 then List.map (set_pat_alias p) disjpat else disjpat in + let alp,sigma = bind_binding_env alp sigma id disjpat in + match_in u alp metas sigma b1 b2 + | _ -> assert false) + | _, _, Name id when is_bindinglist_meta id metas && (not isprod || na1 != Anonymous)-> + let alp,sigma = bind_bindinglist_env alp sigma id [DAst.make ?loc @@ GLocalAssum (na1,bk,t)] in + match_in u alp metas sigma b1 b2 + | _, _, _ -> + let (alp,sigma) = match_names metas (alp,sigma) na1 na2 in + match_in u alp metas sigma b1 b2 + +and match_equations u alp metas sigma (_,(ids,patl1,rhs1)) (patl2,rhs2) rest1 rest2 = (* patl1 and patl2 have the same length because they respectively correspond to some tml1 and tml2 that have the same length *) + let allow_catchall = (rest2 = [] && ids = []) in let (alp,sigma) = - List.fold_left2 (match_cases_pattern_binders metas) + List.fold_left2 (match_cases_pattern_binders allow_catchall metas) (alp,sigma) patl1 patl2 in match_in u alp metas sigma rhs1 rhs2 -let term_of_binder bi = DAst.make @@ match bi with - | Name id -> GVar id - | Anonymous -> GHole (Evar_kinds.InternalHole,Misctypes.IntroAnonymous,None) +and match_disjunctive_equations u alp metas sigma (_,(ids,disjpatl1,rhs1)) (disjpatl2,rhs2) _ _ = + (* patl1 and patl2 have the same length because they respectively + correspond to some tml1 and tml2 that have the same length *) + let (alp,sigma) = + List.fold_left2_set No_match + (fun alp_sigma patl1 patl2 _ _ -> + List.fold_left2 (match_cases_pattern_binders false metas) alp_sigma patl1 patl2) + (alp,sigma) disjpatl1 disjpatl2 in + match_in u alp metas sigma rhs1 rhs2 let match_notation_constr u c (metas,pat) = - let terms,binders,termlists,binderlists = + let terms,termlists,binders,binderlists = match_ false u ([],[]) metas ([],[],[],[]) c pat in - (* Reorder canonically the substitution *) - let find_binder x = - try term_of_binder (Id.List.assoc x binders) - with Not_found -> - (* Happens for binders bound to Anonymous *) - (* Find a better way to propagate Anonymous... *) - DAst.make @@GVar x in - List.fold_right (fun (x,(scl,typ)) (terms',termlists',binders') -> + (* Turning substitution based on binding/constr distinction into a + substitution based on entry productions *) + List.fold_right (fun (x,(scl,typ)) (terms',termlists',binders',binderlists') -> match typ with | NtnTypeConstr -> let term = try Id.List.assoc x terms with Not_found -> raise No_match in - ((term, scl)::terms',termlists',binders') - | NtnTypeOnlyBinder -> - ((find_binder x, scl)::terms',termlists',binders') + ((term, scl)::terms',termlists',binders',binderlists') + | NtnTypeBinder (NtnBinderParsedAsConstr _) -> + (match Id.List.assoc x binders with + | [pat] -> + let v = glob_constr_of_cases_pattern pat in + ((v,scl)::terms',termlists',binders',binderlists') + | _ -> raise No_match) + | NtnTypeBinder (NtnParsedAsIdent | NtnParsedAsPattern _) -> + (terms',termlists',(Id.List.assoc x binders,scl)::binders',binderlists') | NtnTypeConstrList -> - (terms',(Id.List.assoc x termlists,scl)::termlists',binders') + (terms',(Id.List.assoc x termlists,scl)::termlists',binders',binderlists') | NtnTypeBinderList -> let bl = try Id.List.assoc x binderlists with Not_found -> raise No_match in - (terms',termlists',(bl, scl)::binders')) - metas ([],[],[]) + (terms',termlists',binders',(bl, scl)::binderlists')) + metas ([],[],[],[]) (* Matching cases pattern *) @@ -1240,7 +1315,7 @@ let bind_env_cases_pattern (terms,x,termlists,y as sigma) var v = (* TODO: handle the case of multiple occs in different scopes *) (var,v)::terms,x,termlists,y -let match_cases_pattern_list match_fun metas sigma rest x y iter termin lassoc = +let match_cases_pattern_list match_fun metas sigma rest x y iter termin revert = let rec aux sigma acc rest = try let metas = add_ldots_var (add_meta_term y metas) in @@ -1251,10 +1326,10 @@ let match_cases_pattern_list match_fun metas sigma rest x y iter termin lassoc = aux sigma (t::acc) rest with No_match when not (List.is_empty acc) -> acc, match_fun metas sigma rest termin in - let l,(terms,onlybinders,termlists,binderlists as sigma) = aux sigma [] rest in - (terms,onlybinders,(x,if lassoc then l else List.rev l)::termlists, binderlists) + let l,(terms,termlists,binders,binderlists as sigma) = aux sigma [] rest in + (terms,(x,if revert then l else List.rev l)::termlists,binders,binderlists) -let rec match_cases_pattern metas (terms,(),termlists,() as sigma) a1 a2 = +let rec match_cases_pattern metas (terms,termlists,(),() as sigma) a1 a2 = match DAst.get a1, a2 with | r1, NVar id2 when Id.List.mem_assoc id2 metas -> (bind_env_cases_pattern sigma id2 a1),(0,[]) | PatVar Anonymous, NHole _ -> sigma,(0,[]) @@ -1270,10 +1345,10 @@ let rec match_cases_pattern metas (terms,(),termlists,() as sigma) a1 a2 = raise No_match else let l1',more_args = Util.List.chop le2 l1 in - (List.fold_left2 (match_cases_pattern_no_more_args metas) sigma l1' l2),(le2,more_args) - | r1, NList (x,y,iter,termin,lassoc) -> + (List.fold_left2 (match_cases_pattern_no_more_args metas) sigma l1' l2),(le2,more_args) + | r1, NList (x,y,iter,termin,revert) -> (match_cases_pattern_list (match_cases_pattern_no_more_args) - metas (terms,(),termlists,()) a1 x y iter termin lassoc),(0,[]) + metas (terms,termlists,(),()) a1 x y iter termin revert),(0,[]) | _ -> raise No_match and match_cases_pattern_no_more_args metas sigma a1 a2 = @@ -1300,15 +1375,15 @@ let reorder_canonically_substitution terms termlists metas = List.fold_right (fun (x,(scl,typ)) (terms',termlists') -> match typ with | NtnTypeConstr -> ((Id.List.assoc x terms, scl)::terms',termlists') - | NtnTypeOnlyBinder -> assert false + | NtnTypeBinder _ -> assert false | NtnTypeConstrList -> (terms',(Id.List.assoc x termlists,scl)::termlists') | NtnTypeBinderList -> assert false) metas ([],[]) let match_notation_constr_cases_pattern c (metas,pat) = - let (terms,(),termlists,()),more_args = match_cases_pattern metas ([],(),[],()) c pat in + let (terms,termlists,(),()),more_args = match_cases_pattern metas ([],[],(),()) c pat in reorder_canonically_substitution terms termlists metas, more_args let match_notation_constr_ind_pattern ind args (metas,pat) = - let (terms,(),termlists,()),more_args = match_ind_pattern metas ([],(),[],()) ind args pat in + let (terms,termlists,(),()),more_args = match_ind_pattern metas ([],[],(),()) ind args pat in reorder_canonically_substitution terms termlists metas, more_args |