summaryrefslogtreecommitdiff
path: root/interp/topconstr.ml
diff options
context:
space:
mode:
Diffstat (limited to 'interp/topconstr.ml')
-rw-r--r--interp/topconstr.ml588
1 files changed, 384 insertions, 204 deletions
diff --git a/interp/topconstr.ml b/interp/topconstr.ml
index 5911f667..b8a90088 100644
--- a/interp/topconstr.ml
+++ b/interp/topconstr.ml
@@ -6,7 +6,7 @@
(* * GNU Lesser General Public License Version 2.1 *)
(************************************************************************)
-(* $Id$ *)
+(* $Id: topconstr.ml 13357 2010-07-29 22:59:55Z herbelin $ *)
(*i*)
open Pp
@@ -36,6 +36,7 @@ type aconstr =
(* Part only in rawconstr *)
| ALambda of name * aconstr * aconstr
| AProd of name * aconstr * aconstr
+ | ABinderList of identifier * identifier * aconstr * aconstr
| ALetIn of name * aconstr * aconstr
| ACases of case_style * aconstr option *
(aconstr * (name * (inductive * int * name list) option)) list *
@@ -50,6 +51,21 @@ type aconstr =
| APatVar of patvar
| ACast of aconstr * aconstr cast_type
+type scope_name = string
+
+type tmp_scope_name = scope_name
+
+type subscopes = tmp_scope_name option * scope_name list
+
+type notation_var_instance_type =
+ | NtnTypeConstr | NtnTypeConstrList | NtnTypeBinderList
+
+type notation_var_internalization_type =
+ | NtnInternTypeConstr | NtnInternTypeBinder | NtnInternTypeIdent
+
+type interpretation =
+ (identifier * (subscopes * notation_var_instance_type)) list * aconstr
+
(**********************************************************************)
(* Re-interpret a notation as a rawconstr, taking care of binders *)
@@ -69,6 +85,16 @@ let rec cases_pattern_fold_map loc g e = function
let rec subst_rawvars l = function
| RVar (_,id) as r -> (try List.assoc id l with Not_found -> r)
+ | RProd (loc,Name id,bk,t,c) ->
+ let id =
+ try match List.assoc id l with RVar(_,id') -> id' | _ -> id
+ with Not_found -> id in
+ RProd (loc,Name id,bk,subst_rawvars l t,subst_rawvars l c)
+ | RLambda (loc,Name id,bk,t,c) ->
+ let id =
+ try match List.assoc id l with RVar(_,id') -> id' | _ -> id
+ with Not_found -> id in
+ RLambda (loc,Name id,bk,subst_rawvars l t,subst_rawvars l c)
| r -> map_rawconstr (subst_rawvars l) r (* assume: id is not binding *)
let ldots_var = id_of_string ".."
@@ -82,6 +108,12 @@ let rawconstr_of_aconstr_with_binders loc g f e = function
let inner = RApp (loc,RVar (loc,ldots_var),[subst_rawvars innerl it]) in
let outerl = (ldots_var,inner)::(if swap then [x,RVar(loc,y)] else []) in
subst_rawvars outerl it
+ | ABinderList (x,y,iter,tail) ->
+ let t = f e tail in let it = f e iter in
+ let innerl = [(ldots_var,t);(x,RVar(loc,y))] in
+ let inner = RApp (loc,RVar (loc,ldots_var),[subst_rawvars innerl it]) in
+ let outerl = [(ldots_var,inner)] in
+ subst_rawvars outerl it
| ALambda (na,ty,c) ->
let e,na = g e na in RLambda (loc,na,Explicit,f e ty,f e c)
| AProd (na,ty,c) ->
@@ -134,72 +166,135 @@ let rec rawconstr_of_aconstr loc x =
(****************************************************************************)
(* Translating a rawconstr into a notation, interpreting recursive patterns *)
-let add_name r = function
- | Anonymous -> ()
- | Name id -> r := id :: !r
+let add_id r id = r := (id :: pi1 !r, pi2 !r, pi3 !r)
+let add_name r = function Anonymous -> () | Name id -> add_id r id
-let has_ldots =
- List.exists
- (function RApp (_,RVar(_,v),_) when v = ldots_var -> true | _ -> false)
-
-let compare_rawconstr f t1 t2 = match t1,t2 with
- | RRef (_,r1), RRef (_,r2) -> eq_gr r1 r2
- | RVar (_,v1), RVar (_,v2) -> v1 = v2
- | RApp (_,f1,l1), RApp (_,f2,l2) -> f f1 f2 & List.for_all2 f l1 l2
- | RLambda (_,na1,bk1,ty1,c1), RLambda (_,na2,bk2,ty2,c2) when na1 = na2 && bk1 = bk2 ->
- f ty1 ty2 & f c1 c2
+let split_at_recursive_part c =
+ let sub = ref None in
+ let rec aux = function
+ | RApp (loc0,RVar(loc,v),c::l) when v = ldots_var ->
+ if !sub <> None then
+ (* Not narrowed enough to find only one recursive part *)
+ raise Not_found
+ else
+ (sub := Some c;
+ if l = [] then RVar (loc,ldots_var)
+ else RApp (loc0,RVar (loc,ldots_var),l))
+ | c -> map_rawconstr aux c in
+ let outer_iterator = aux c in
+ match !sub with
+ | None -> (* No recursive pattern found *) raise Not_found
+ | Some c ->
+ match outer_iterator with
+ | RVar (_,v) when v = ldots_var -> (* Not enough context *) raise Not_found
+ | _ -> outer_iterator, c
+
+let on_true_do b f c = if b then (f c; b) else b
+
+let compare_rawconstr f add t1 t2 = match t1,t2 with
+ | RRef (_,r1), RRef (_,r2) -> eq_gr r1 r2
+ | RVar (_,v1), RVar (_,v2) -> on_true_do (v1 = v2) add (Name v1)
+ | RApp (_,f1,l1), RApp (_,f2,l2) -> f f1 f2 & list_for_all2eq f l1 l2
+ | RLambda (_,na1,bk1,ty1,c1), RLambda (_,na2,bk2,ty2,c2) when na1 = na2 && bk1 = bk2 -> on_true_do (f ty1 ty2 & f c1 c2) add na1
| RProd (_,na1,bk1,ty1,c1), RProd (_,na2,bk2,ty2,c2) when na1 = na2 && bk1 = bk2 ->
- f ty1 ty2 & f c1 c2
+ on_true_do (f ty1 ty2 & f c1 c2) add na1
| RHole _, RHole _ -> true
| RSort (_,s1), RSort (_,s2) -> s1 = s2
- | (RLetIn _ | RCases _ | RRec _ | RDynamic _
+ | RLetIn (_,na1,b1,c1), RLetIn (_,na2,b2,c2) when na1 = na2 ->
+ on_true_do (f b1 b2 & f c1 c2) add na1
+ | (RCases _ | RRec _ | RDynamic _
| RPatVar _ | REvar _ | RLetTuple _ | RIf _ | RCast _),_
- | _,(RLetIn _ | RCases _ | RRec _ | RDynamic _
+ | _,(RCases _ | RRec _ | RDynamic _
| RPatVar _ | REvar _ | RLetTuple _ | RIf _ | RCast _)
-> error "Unsupported construction in recursive notations."
- | (RRef _ | RVar _ | RApp _ | RLambda _ | RProd _ | RHole _ | RSort _), _
+ | (RRef _ | RVar _ | RApp _ | RLambda _ | RProd _
+ | RHole _ | RSort _ | RLetIn _), _
-> false
-let rec eq_rawconstr t1 t2 = compare_rawconstr eq_rawconstr t1 t2
+let rec eq_rawconstr t1 t2 = compare_rawconstr eq_rawconstr (fun _ -> ()) t1 t2
+
+let subtract_loc loc1 loc2 = make_loc (fst (unloc loc1),fst (unloc loc2)-1)
-let discriminate_patterns foundvars nl l1 l2 =
+let check_is_hole id = function RHole _ -> () | t ->
+ user_err_loc (loc_of_rawconstr t,"",
+ strbrk "In recursive notation with binders, " ++ pr_id id ++
+ strbrk " is expected to come without type.")
+
+let compare_recursive_parts found f (iterator,subc) =
let diff = ref None in
- let rec aux n c1 c2 = match c1,c2 with
- | RVar (_,v1), RVar (_,v2) when v1<>v2 ->
- if !diff = None then (diff := Some (v1,v2,(n>=nl)); true)
- else
- !diff = Some (v1,v2,(n>=nl)) or !diff = Some (v2,v1,(n<nl))
- or (error
- "Both ends of the recursive pattern differ in more than one place")
- | _ -> compare_rawconstr (aux (n+1)) c1 c2 in
- let l = list_map2_i aux 0 l1 l2 in
- if not (List.for_all ((=) true) l) then
- error "Both ends of the recursive pattern differ.";
- match !diff with
- | None -> error "Both ends of the recursive pattern are the same."
- | Some (x,y,_ as discr) ->
- List.iter (fun id ->
- if List.mem id !foundvars
- then errorlabstrm "" (strbrk "Variables used in the recursive part of a pattern are not allowed to occur outside of the recursive part.");
- foundvars := id::!foundvars) [x;y];
- discr
+ let terminator = ref None in
+ let rec aux c1 c2 = match c1,c2 with
+ | RVar(_,v), term when v = ldots_var ->
+ (* We found the pattern *)
+ assert (!terminator = None); terminator := Some term;
+ true
+ | RApp (_,RVar(_,v),l1), RApp (_,term,l2) when v = ldots_var ->
+ (* We found the pattern, but there are extra arguments *)
+ (* (this allows e.g. alternative (recursive) notation of application) *)
+ assert (!terminator = None); terminator := Some term;
+ list_for_all2eq aux l1 l2
+ | RVar (_,x), RVar (_,y) when x<>y ->
+ (* We found the position where it differs *)
+ let lassoc = (!terminator <> None) in
+ let x,y = if lassoc then y,x else x,y in
+ !diff = None && (diff := Some (x,y,Some lassoc); true)
+ | RLambda (_,Name x,_,t_x,c), RLambda (_,Name y,_,t_y,term)
+ | RProd (_,Name x,_,t_x,c), RProd (_,Name y,_,t_y,term) ->
+ (* We found a binding position where it differs *)
+ check_is_hole y t_x;
+ check_is_hole y t_y;
+ !diff = None && (diff := Some (x,y,None); aux c term)
+ | _ ->
+ compare_rawconstr aux (add_name found) c1 c2 in
+ if aux iterator subc then
+ match !diff with
+ | None ->
+ let loc1 = loc_of_rawconstr iterator in
+ let loc2 = loc_of_rawconstr (Option.get !terminator) in
+ (* Here, we would need a loc made of several parts ... *)
+ user_err_loc (subtract_loc loc1 loc2,"",
+ str "Both ends of the recursive pattern are the same.")
+ | Some (x,y,Some lassoc) ->
+ let newfound = (pi1 !found, (x,y) :: pi2 !found, pi3 !found) in
+ let iterator =
+ f (if lassoc then subst_rawvars [y,RVar(dummy_loc,x)] iterator
+ else iterator) in
+ (* found have been collected by compare_constr *)
+ found := newfound;
+ AList (x,y,iterator,f (Option.get !terminator),lassoc)
+ | Some (x,y,None) ->
+ let newfound = (pi1 !found, pi2 !found, (x,y) :: pi3 !found) in
+ let iterator = f iterator in
+ (* found have been collected by compare_constr *)
+ found := newfound;
+ ABinderList (x,y,iterator,f (Option.get !terminator))
+ else
+ raise Not_found
let aconstr_and_vars_of_rawconstr a =
- let found = ref [] in
- let rec aux = function
- | RVar (_,id) -> found := id::!found; AVar id
- | RApp (_,f,args) when has_ldots args -> make_aconstr_list f args
- | RApp (_,RVar (_,f),[RApp (_,t,[c]);d]) when f = ldots_var ->
- (* Special case for alternative (recursive) notation of application *)
- let x,y,lassoc = discriminate_patterns found 0 [c] [d] in
- found := ldots_var :: !found; assert lassoc;
- AList (x,y,AApp (AVar ldots_var,[AVar x]),aux t,lassoc)
+ let found = ref ([],[],[]) in
+ let rec aux c =
+ let keepfound = !found in
+ (* n^2 complexity but small and done only once per notation *)
+ try compare_recursive_parts found aux' (split_at_recursive_part c)
+ with Not_found ->
+ found := keepfound;
+ match c with
+ | RApp (_,RVar (loc,f),[c]) when f = ldots_var ->
+ (* Fall on the second part of the recursive pattern w/o having
+ found the first part *)
+ user_err_loc (loc,"",
+ str "Cannot find where the recursive pattern starts.")
+ | c ->
+ aux' c
+ and aux' = function
+ | RVar (_,id) -> add_id found id; AVar id
| RApp (_,g,args) -> AApp (aux g, List.map aux args)
| RLambda (_,na,bk,ty,c) -> add_name found na; ALambda (na,aux ty,aux c)
| RProd (_,na,bk,ty,c) -> add_name found na; AProd (na,aux ty,aux c)
| RLetIn (_,na,b,c) -> add_name found na; ALetIn (na,aux b,aux c)
| RCases (_,sty,rtntypopt,tml,eqnl) ->
- let f (_,idl,pat,rhs) = found := idl@(!found); (pat,aux rhs) in
+ let f (_,idl,pat,rhs) = List.iter (add_id found) idl; (pat,aux rhs) in
ACases (sty,Option.map aux rtntypopt,
List.map (fun (tm,(na,x)) ->
add_name found na;
@@ -215,7 +310,7 @@ let aconstr_and_vars_of_rawconstr a =
add_name found na;
AIf (aux c,(na,Option.map aux po),aux b1,aux b2)
| RRec (_,fk,idl,dll,tl,bl) ->
- Array.iter (fun id -> found := id::!found) idl;
+ Array.iter (add_id found) idl;
let dll = Array.map (List.map (fun (na,bk,oc,b) ->
if bk <> Explicit then
error "Binders marked as implicit not allowed in notations.";
@@ -231,51 +326,61 @@ let aconstr_and_vars_of_rawconstr a =
| RDynamic _ | REvar _ ->
error "Existential variables not allowed in notations."
- (* Recognizing recursive notations *)
- and terminator_of_pat f1 ll1 lr1 = function
- | RApp (loc,f2,l2) ->
- if not (eq_rawconstr f1 f2) then errorlabstrm ""
- (strbrk "Cannot recognize the same head to both ends of the recursive pattern.");
- let nl = List.length ll1 in
- let nr = List.length lr1 in
- if List.length l2 <> nl + nr + 1 then
- error "Both ends of the recursive pattern have different lengths.";
- let ll2,l2' = list_chop nl l2 in
- let t = List.hd l2' and lr2 = List.tl l2' in
- let x,y,order = discriminate_patterns found nl (ll1@lr1) (ll2@lr2) in
- let iter =
- if order then RApp (loc,f2,ll2@RVar (loc,ldots_var)::lr2)
- else RApp (loc,f1,ll1@RVar (loc,ldots_var)::lr1) in
- (if order then y else x),(if order then x else y), aux iter, aux t, order
- | _ -> error "One end of the recursive pattern is not an application."
-
- and make_aconstr_list f args =
- let rec find_patterns acc = function
- | RApp(_,RVar (_,a),[c]) :: l when a = ldots_var ->
- (* We've found the recursive part *)
- let x,y,iter,term,lassoc = terminator_of_pat f (List.rev acc) l c in
- AList (x,y,iter,term,lassoc)
- | a::l -> find_patterns (a::acc) l
- | [] -> error "Ill-formed recursive notation."
- in find_patterns [] args
-
in
let t = aux a in
(* Side effect *)
t, !found
-let aconstr_of_rawconstr vars a =
- let a,foundvars = aconstr_and_vars_of_rawconstr a in
- let check_type x =
- if not (List.mem x foundvars) then
- error ((string_of_id x)^" is unbound in the right-hand-side.") in
- List.iter check_type vars;
+let rec list_rev_mem_assoc x = function
+ | [] -> false
+ | (_,x')::l -> x = x' || list_rev_mem_assoc x l
+
+let check_variables vars recvars (found,foundrec,foundrecbinding) =
+ let useless_vars = List.map snd recvars in
+ let vars = List.filter (fun (y,_) -> not (List.mem y useless_vars)) vars in
+ let check_recvar x =
+ if List.mem x found then
+ errorlabstrm "" (pr_id x ++
+ strbrk " should only be used in the recursive part of a pattern.") in
+ List.iter (fun (x,y) -> check_recvar x; check_recvar y)
+ (foundrec@foundrecbinding);
+ let check_bound x =
+ if not (List.mem x found) then
+ if List.mem_assoc x foundrec or List.mem_assoc x foundrecbinding
+ or list_rev_mem_assoc x foundrec or list_rev_mem_assoc x foundrecbinding
+ then
+ error ((string_of_id x)^" should not be bound in a recursive pattern of the right-hand side.")
+ else
+ error ((string_of_id x)^" is unbound in the right-hand side.") in
+ let check_pair s x y where =
+ if not (List.mem (x,y) where) then
+ errorlabstrm "" (strbrk "in the right-hand side, " ++ pr_id x ++
+ str " and " ++ pr_id y ++ strbrk " should appear in " ++ str s ++
+ str " position as part of a recursive pattern.") in
+ let check_type (x,typ) =
+ match typ with
+ | NtnInternTypeConstr ->
+ begin
+ try check_pair "term" x (List.assoc x recvars) foundrec
+ with Not_found -> check_bound x
+ end
+ | NtnInternTypeBinder ->
+ begin
+ try check_pair "binding" x (List.assoc x recvars) foundrecbinding
+ with Not_found -> check_bound x
+ end
+ | NtnInternTypeIdent -> check_bound x in
+ List.iter check_type vars
+
+let aconstr_of_rawconstr vars recvars a =
+ let a,found = aconstr_and_vars_of_rawconstr a in
+ check_variables vars recvars found;
a
(* Substitution of kernel names, avoiding a list of bound identifiers *)
let aconstr_of_constr avoiding t =
- aconstr_of_rawconstr [] (Detyping.detype false avoiding [] t)
+ aconstr_of_rawconstr [] [] (Detyping.detype false avoiding [] t)
let rec subst_pat subst pat =
match pat with
@@ -319,6 +424,12 @@ let rec subst_aconstr subst bound raw =
if r1' == r1 && r2' == r2 then raw else
AProd (n,r1',r2')
+ | ABinderList (id1,id2,r1,r2) ->
+ let r1' = subst_aconstr subst bound r1
+ and r2' = subst_aconstr subst bound r2 in
+ if r1' == r1 && r2' == r2 then raw else
+ ABinderList (id1,id2,r1',r2')
+
| ALetIn (n,r1,r2) ->
let r1' = subst_aconstr subst bound r1
and r2' = subst_aconstr subst bound r2 in
@@ -396,7 +507,7 @@ let rec subst_aconstr subst bound raw =
ACast (r1',CastCoerce)
let subst_interpretation subst (metas,pat) =
- let bound = List.map fst (fst metas @ snd metas) in
+ let bound = List.map fst metas in
(metas,subst_aconstr subst bound pat)
let encode_list_value l = RApp (dummy_loc,RVar (dummy_loc,ldots_var),l)
@@ -434,7 +545,7 @@ let rec alpha_var id1 id2 = function
let alpha_eq_val (x,y) = x = y
-let bind_env alp (sigma,sigmalist as fullsigma) var v =
+let bind_env alp (sigma,sigmalist,sigmabinders as fullsigma) var v =
try
let vvar = List.assoc var sigma in
if alpha_eq_val (v,vvar) then fullsigma
@@ -443,7 +554,10 @@ let bind_env alp (sigma,sigmalist as fullsigma) var v =
(* Check that no capture of binding variables occur *)
if List.exists (fun (id,_) ->occur_rawconstr id v) alp then raise No_match;
(* TODO: handle the case of multiple occs in different scopes *)
- ((var,v)::sigma,sigmalist)
+ ((var,v)::sigma,sigmalist,sigmabinders)
+
+let bind_binder (sigma,sigmalist,sigmabinders) x bl =
+ (sigma,sigmalist,(x,List.rev bl)::sigmabinders)
let match_fix_kind fk1 fk2 =
match (fk1,fk2) with
@@ -458,13 +572,9 @@ let match_opt f sigma t1 t2 = match (t1,t2) with
| Some t1, Some t2 -> f sigma t1 t2
| _ -> raise No_match
-let rawconstr_of_name = function
- | Anonymous -> RHole (dummy_loc,Evd.InternalHole)
- | Name id -> RVar (dummy_loc,id)
-
let match_names metas (alp,sigma) na1 na2 = match (na1,na2) with
- | (na,Name id2) when List.mem id2 metas ->
- alp, bind_env alp sigma id2 (rawconstr_of_name na)
+ | (Name id1,Name id2) when List.mem id2 (fst metas) ->
+ alp, bind_env alp sigma id2 (RVar (dummy_loc,id1))
| (Name id1,Name id2) -> (id1,id2)::alp,sigma
| (Anonymous,Anonymous) -> alp,sigma
| _ -> raise No_match
@@ -482,33 +592,80 @@ let adjust_application_n n loc f l =
let l1,l2 = list_chop (List.length l - n) l in
if l1 = [] then f,l else RApp (loc,f,l1), l2
-let match_alist match_fun metas sigma l1 l2 x iter termin lassoc =
- (* match the iterator at least once *)
- let sigmavar,sigmalist =
- List.fold_left2 (match_fun (ldots_var::metas)) sigma l1 l2 in
- (* Recover the recursive position *)
- let rest = List.assoc ldots_var sigmavar in
- (* Recover the first element *)
- let t1 = List.assoc x sigmavar in
- let sigmavar = List.remove_assoc x (List.remove_assoc ldots_var sigmavar) in
- (* try to find the remaining elements or the terminator *)
- let rec match_alist_tail metas sigma acc rest =
+let glue_letin_with_decls = true
+
+let rec match_iterated_binders islambda decls = function
+ | RLambda (_,na,bk,t,b) when islambda ->
+ match_iterated_binders islambda ((na,bk,None,t)::decls) b
+ | RProd (_,(Name _ as na),bk,t,b) when not islambda ->
+ match_iterated_binders islambda ((na,bk,None,t)::decls) b
+ | RLetIn (loc,na,c,b) when glue_letin_with_decls ->
+ match_iterated_binders islambda
+ ((na,Explicit (*?*), Some c,RHole(loc,Evd.BinderType na))::decls) b
+ | b -> (decls,b)
+
+let remove_sigma x (sigmavar,sigmalist,sigmabinders) =
+ (List.remove_assoc x sigmavar,sigmalist,sigmabinders)
+
+let rec match_abinderlist_with_app match_fun metas sigma rest x iter termin =
+ let rec aux sigma acc rest =
try
- let sigmavar,sigmalist = match_fun (ldots_var::metas) sigma rest iter in
- let rest = List.assoc ldots_var sigmavar in
- let t = List.assoc x sigmavar in
- let sigmavar =
- List.remove_assoc x (List.remove_assoc ldots_var sigmavar) in
- match_alist_tail metas (sigmavar,sigmalist) (t::acc) rest
- with No_match ->
- List.rev acc, match_fun metas (sigmavar,sigmalist) rest termin in
- let tl,(sigmavar,sigmalist) = match_alist_tail metas sigma [t1] rest in
- (sigmavar, (x,if lassoc then List.rev tl else tl)::sigmalist)
-
-let rec match_ alp metas sigma a1 a2 = match (a1,a2) with
- | r1, AVar id2 when List.mem id2 metas -> bind_env alp sigma id2 r1
+ let sigma = match_fun (ldots_var::fst metas,snd metas) sigma rest iter in
+ let rest = List.assoc ldots_var (pi1 sigma) in
+ let b = match List.assoc x (pi3 sigma) with [b] -> b | _ ->assert false in
+ let sigma = remove_sigma x (remove_sigma ldots_var sigma) in
+ aux sigma (b::acc) rest
+ with No_match when acc <> [] ->
+ acc, match_fun metas sigma rest termin in
+ let bl,sigma = aux sigma [] rest in
+ bind_binder sigma x bl
+
+let match_alist match_fun metas sigma rest x iter termin lassoc =
+ let rec aux sigma acc rest =
+ try
+ let sigma = match_fun (ldots_var::fst metas,snd metas) sigma rest iter in
+ let rest = List.assoc ldots_var (pi1 sigma) in
+ let t = List.assoc x (pi1 sigma) in
+ let sigma = remove_sigma x (remove_sigma ldots_var sigma) in
+ aux sigma (t::acc) rest
+ with No_match when acc <> [] ->
+ acc, match_fun metas sigma rest termin in
+ let l,sigma = aux sigma [] rest in
+ (pi1 sigma, (x,if lassoc then l else List.rev l)::pi2 sigma, pi3 sigma)
+
+let rec match_ alp (tmetas,blmetas as metas) sigma a1 a2 = match (a1,a2) with
+
+ (* Matching notation variable *)
+ | r1, AVar id2 when List.mem id2 tmetas -> bind_env alp sigma id2 r1
+
+ (* Matching recursive notations for terms *)
+ | r1, AList (x,_,iter,termin,lassoc) ->
+ match_alist (match_ alp) metas sigma r1 x iter termin lassoc
+
+ (* Matching recursive notations for binders: ad hoc cases supporting let-in *)
+ | RLambda (_,na1,bk,t1,b1), ABinderList (x,_,ALambda (Name id2,_,b2),termin)->
+ let (decls,b) = match_iterated_binders true [(na1,bk,None,t1)] b1 in
+ (* TODO: address the possibility that termin is a Lambda itself *)
+ match_ alp metas (bind_binder sigma x decls) b termin
+ | RProd (_,na1,bk,t1,b1), ABinderList (x,_,AProd (Name id2,_,b2),termin)
+ when na1 <> Anonymous ->
+ let (decls,b) = match_iterated_binders false [(na1,bk,None,t1)] b1 in
+ (* TODO: address the possibility that termin is a Prod itself *)
+ match_ alp metas (bind_binder sigma x decls) b termin
+ (* Matching recursive notations for binders: general case *)
+ | r, ABinderList (x,_,iter,termin) ->
+ match_abinderlist_with_app (match_ alp) metas sigma r x iter termin
+
+ (* Matching individual binders as part of a recursive pattern *)
+ | RLambda (_,na,bk,t,b1), ALambda (Name id,_,b2) when List.mem id blmetas ->
+ match_ alp metas (bind_binder sigma id [(na,bk,None,t)]) b1 b2
+ | RProd (_,na,bk,t,b1), AProd (Name id,_,b2)
+ when List.mem id blmetas & na <> Anonymous ->
+ match_ alp metas (bind_binder sigma id [(na,bk,None,t)]) b1 b2
+
+ (* Matching compositionally *)
| RVar (_,id1), AVar id2 when alpha_var id1 id2 alp -> sigma
- | RRef (_,r1), ARef r2 when (eq_gr r1 r2) -> sigma
+ | RRef (_,r1), ARef r2 when (eq_gr r1 r2) -> sigma
| RPatVar (_,(_,n1)), APatVar n2 when n1=n2 -> sigma
| RApp (loc,f1,l1), AApp (f2,l2) ->
let n1 = List.length l1 and n2 = List.length l2 in
@@ -519,11 +676,6 @@ let rec match_ alp metas sigma a1 a2 = match (a1,a2) with
let l11,l12 = list_chop (n1-n2) l1 in RApp (loc,f1,l11),l12, f2,l2
else f1,l1, f2, l2 in
List.fold_left2 (match_ alp metas) (match_ alp metas sigma f1 f2) l1 l2
- | RApp (loc,f1,l1), AList (x,_,(AApp (f2,l2) as iter),termin,lassoc)
- when List.length l1 >= List.length l2 ->
- let f1,l1 = adjust_application_n (List.length l2) loc f1 l1 in
- match_alist (match_ alp)
- metas sigma (f1::l1) (f2::l2) x iter termin lassoc
| RLambda (_,na1,_,t1,b1), ALambda (na2,t2,b2) ->
match_binders alp metas na1 na2 (match_ alp metas sigma t1 t2) b1 b2
| RProd (_,na1,_,t1,b1), AProd (na2,t2,b2) ->
@@ -588,38 +740,36 @@ and match_equations alp metas sigma (_,_,patl1,rhs1) (patl2,rhs2) =
(alp,sigma) patl1 patl2 in
match_ alp metas sigma rhs1 rhs2
-type scope_name = string
-
-type tmp_scope_name = scope_name
-
-type subscopes = tmp_scope_name option * scope_name list
-
-type interpretation =
- (* regular vars of notation / recursive parts of notation / body *)
- ((identifier * subscopes) list * (identifier * subscopes) list) * aconstr
-
-let match_aconstr c ((metas_scl,metaslist_scl),pat) =
- let vars = List.map fst metas_scl @ List.map fst metaslist_scl in
- let subst,substlist = match_ [] vars ([],[]) c pat in
+let match_aconstr c (metas,pat) =
+ let vars = list_split_by (fun (_,(_,x)) -> x <> NtnTypeBinderList) metas in
+ let vars = (List.map fst (fst vars), List.map fst (snd vars)) in
+ let terms,termlists,binders = match_ [] vars ([],[],[]) c pat in
(* Reorder canonically the substitution *)
let find x =
- try List.assoc x subst
+ try List.assoc x terms
with Not_found ->
(* Happens for binders bound to Anonymous *)
(* Find a better way to propagate Anonymous... *)
RVar (dummy_loc,x) in
- List.map (fun (x,scl) -> (find x,scl)) metas_scl,
- List.map (fun (x,scl) -> (List.assoc x substlist,scl)) metaslist_scl
+ List.fold_right (fun (x,(scl,typ)) (terms',termlists',binders') ->
+ match typ with
+ | NtnTypeConstr ->
+ ((find x, scl)::terms',termlists',binders')
+ | NtnTypeConstrList ->
+ (terms',(List.assoc x termlists,scl)::termlists',binders')
+ | NtnTypeBinderList ->
+ (terms',termlists',(List.assoc x binders,scl)::binders'))
+ metas ([],[],[])
(* Matching cases pattern *)
-let bind_env_cases_pattern (sigma,sigmalist as fullsigma) var v =
+let bind_env_cases_pattern (sigma,sigmalist,x as fullsigma) var v =
try
let vvar = List.assoc var sigma in
if v=vvar then fullsigma else raise No_match
with Not_found ->
(* TODO: handle the case of multiple occs in different scopes *)
- (var,v)::sigma,sigmalist
+ (var,v)::sigma,sigmalist,x
let rec match_cases_pattern metas sigma a1 a2 = match (a1,a2) with
| r1, AVar id2 when List.mem id2 metas -> bind_env_cases_pattern sigma id2 r1
@@ -639,26 +789,21 @@ let rec match_cases_pattern metas sigma a1 a2 = match (a1,a2) with
(* All parameters must be _ *)
List.iter (function AHole _ -> () | _ -> raise No_match) p2;
List.fold_left2 (match_cases_pattern metas) sigma args1 args2
- | PatCstr (loc,(ind,_ as r1),args1,_),
- AList (x,_,(AApp (ARef (ConstructRef r2),l2) as iter),termin,lassoc)
- when r1 = r2 ->
- let nparams = Inductive.inductive_params (Global.lookup_inductive ind) in
- assert (List.length args1 + nparams = List.length l2);
- let (p2,args2) = list_chop nparams l2 in
- List.iter (function AHole _ -> () | _ -> raise No_match) p2;
- match_alist match_cases_pattern
- metas sigma args1 args2 x iter termin lassoc
+ | r1, AList (x,_,iter,termin,lassoc) ->
+ match_alist (fun (metas,_) -> match_cases_pattern metas)
+ (metas,[]) (pi1 sigma,pi2 sigma,()) r1 x iter termin lassoc
| _ -> raise No_match
-let match_aconstr_cases_pattern c ((metas_scl,metaslist_scl),pat) =
- let vars = List.map fst metas_scl @ List.map fst metaslist_scl in
- let subst,substlist = match_cases_pattern vars ([],[]) c pat in
+let match_aconstr_cases_pattern c (metas,pat) =
+ let vars = List.map fst metas in
+ let terms,termlists,() = match_cases_pattern vars ([],[],()) c pat in
(* Reorder canonically the substitution *)
- let find x subst =
- try List.assoc x subst
- with Not_found -> anomaly "match_aconstr_cases_pattern" in
- List.map (fun (x,scl) -> (find x subst,scl)) metas_scl,
- List.map (fun (x,scl) -> (find x substlist,scl)) metaslist_scl
+ List.fold_right (fun (x,(scl,typ)) (terms',termlists') ->
+ match typ with
+ | NtnTypeConstr -> ((List.assoc x terms, scl)::terms',termlists')
+ | NtnTypeConstrList -> (terms',(List.assoc x termlists,scl)::termlists')
+ | NtnTypeBinderList -> assert false)
+ metas ([],[])
(**********************************************************************)
(*s Concrete syntax for terms *)
@@ -675,19 +820,20 @@ type proj_flag = int option (* [Some n] = proj of the n-th visible argument *)
type prim_token = Numeral of Bigint.bigint | String of string
-type 'a notation_substitution =
- 'a list * (* for recursive notations: *) 'a list list
-
type cases_pattern_expr =
| CPatAlias of loc * cases_pattern_expr * identifier
| CPatCstr of loc * reference * cases_pattern_expr list
| CPatAtom of loc * reference option
| CPatOr of loc * cases_pattern_expr list
- | CPatNotation of loc * notation * cases_pattern_expr notation_substitution
+ | CPatNotation of loc * notation * cases_pattern_notation_substitution
| CPatPrim of loc * prim_token
- | CPatRecord of loc * (reference * cases_pattern_expr) list
+ | CPatRecord of Util.loc * (reference * cases_pattern_expr) list
| CPatDelimiters of loc * string * cases_pattern_expr
+and cases_pattern_notation_substitution =
+ cases_pattern_expr list * (** for constr subterms *)
+ cases_pattern_expr list list (** for recursive notations *)
+
type constr_expr =
| CRef of reference
| CFix of loc * identifier located * fix_expr list
@@ -701,18 +847,18 @@ type constr_expr =
(constr_expr * explicitation located option) list
| CRecord of loc * constr_expr option * (reference * constr_expr) list
| CCases of loc * case_style * constr_expr option *
- (constr_expr * (name option * constr_expr option)) list *
+ (constr_expr * (name located option * constr_expr option)) list *
(loc * cases_pattern_expr list located list * constr_expr) list
- | CLetTuple of loc * name list * (name option * constr_expr option) *
+ | CLetTuple of loc * name located list * (name located option * constr_expr option) *
constr_expr * constr_expr
- | CIf of loc * constr_expr * (name option * constr_expr option)
+ | CIf of loc * constr_expr * (name located option * constr_expr option)
* constr_expr * constr_expr
| CHole of loc * Evd.hole_kind option
| CPatVar of loc * (bool * patvar)
| CEvar of loc * existential_key * constr_expr list option
| CSort of loc * rawsort
| CCast of loc * constr_expr * constr_expr cast_type
- | CNotation of loc * notation * constr_expr notation_substitution
+ | CNotation of loc * notation * constr_notation_substitution
| CGeneralization of loc * binding_kind * abstraction_kind option * constr_expr
| CPrim of loc * prim_token
| CDelimiters of loc * string * constr_expr
@@ -721,14 +867,6 @@ type constr_expr =
and fix_expr =
identifier located * (identifier located option * recursion_order_expr) * local_binder list * constr_expr * constr_expr
-and local_binder =
- | LocalRawDef of name located * constr_expr
- | LocalRawAssum of name located list * binder_kind * constr_expr
-
-and typeclass_constraint = name located * binding_kind * constr_expr
-
-and typeclass_context = typeclass_constraint list
-
and cofix_expr =
identifier located * local_binder list * constr_expr * constr_expr
@@ -737,6 +875,19 @@ and recursion_order_expr =
| CWfRec of constr_expr
| CMeasureRec of constr_expr * constr_expr option (* measure, relation *)
+and local_binder =
+ | LocalRawDef of name located * constr_expr
+ | LocalRawAssum of name located list * binder_kind * constr_expr
+
+and constr_notation_substitution =
+ constr_expr list * (* for constr subterms *)
+ constr_expr list list * (* for recursive notations *)
+ local_binder list list (* for binders subexpressions *)
+
+type typeclass_constraint = name located * binding_kind * constr_expr
+
+and typeclass_context = typeclass_constraint list
+
type constr_pattern_expr = constr_expr
(***********************)
@@ -789,6 +940,15 @@ let cases_pattern_expr_loc = function
| CPatPrim (loc,_) -> loc
| CPatDelimiters (loc,_,_) -> loc
+let local_binder_loc = function
+ | LocalRawAssum ((loc,_)::_,_,t)
+ | LocalRawDef ((loc,_),t) -> join_loc loc (constr_loc t)
+ | LocalRawAssum ([],_,_) -> assert false
+
+let local_binders_loc bll =
+ if bll = [] then dummy_loc else
+ join_loc (local_binder_loc (List.hd bll)) (local_binder_loc (list_last bll))
+
let occur_var_constr_ref id = function
| Ident (loc,id') -> id = id'
| Qualid _ -> false
@@ -798,7 +958,7 @@ let ids_of_cases_indtype =
let rec vars_of = function
(* We deal only with the regular cases *)
| CApp (_,_,l) -> List.fold_left add_var [] (List.map fst l)
- | CNotation (_,_,(l,[]))
+ | CNotation (_,_,(l,[],[]))
(* assume the ntn is applicative and does not instantiate the head !! *)
| CAppExpl (_,_,l) -> List.fold_left add_var [] l
| CDelimiters(_,_,c) -> vars_of c
@@ -809,7 +969,7 @@ let ids_of_cases_tomatch tms =
List.fold_right
(fun (_,(ona,indnal)) l ->
Option.fold_right (fun t -> (@) (ids_of_cases_indtype t))
- indnal (Option.fold_right name_cons ona l))
+ indnal (Option.fold_right (down_located name_cons) ona l))
tms []
let is_constructor id =
@@ -849,7 +1009,7 @@ let rec fold_local_binders g f n acc b = function
f n (fold_local_binders g f n' acc b l) t
| LocalRawDef ((_,na),t)::l ->
f n (fold_local_binders g f (name_fold g na n) acc b l) t
- | _ ->
+ | [] ->
f n acc b
let fold_constr_expr_with_binders g f n acc = function
@@ -860,7 +1020,11 @@ let fold_constr_expr_with_binders g f n acc = function
| CLetIn (_,na,a,b) -> fold_constr_expr_binders g f n acc b [[na],default_binder_kind,a]
| CCast (loc,a,CastConv(_,b)) -> f n (f n acc a) b
| CCast (loc,a,CastCoerce) -> f n acc a
- | CNotation (_,_,(l,ll)) -> List.fold_left (f n) acc (l@List.flatten ll)
+ | CNotation (_,_,(l,ll,bll)) ->
+ (* The following is an approximation: we don't know exactly if
+ an ident is binding nor to which subterms bindings apply *)
+ let acc = List.fold_left (f n) acc (l@List.flatten ll) in
+ List.fold_left (fun acc bl -> fold_local_binders g f n acc (CHole (dummy_loc,None)) bl) acc bll
| CGeneralization (_,_,_,c) -> f n acc c
| CDelimiters (loc,_,a) -> f n acc a
| CHole _ | CEvar _ | CPatVar _ | CSort _ | CPrim _ | CDynamic _ | CRef _ ->
@@ -874,11 +1038,12 @@ let fold_constr_expr_with_binders g f n acc = function
let ids = ids_of_pattern_list patl in
f (Idset.fold g ids n) acc rhs) bl acc
| CLetTuple (loc,nal,(ona,po),b,c) ->
- let n' = List.fold_right (name_fold g) nal n in
- f (Option.fold_right (name_fold g) ona n') (f n acc b) c
+ let n' = List.fold_right (down_located (name_fold g)) nal n in
+ f (Option.fold_right (down_located (name_fold g)) ona n') (f n acc b) c
| CIf (_,c,(ona,po),b1,b2) ->
let acc = f n (f n (f n acc b1) b2) c in
- Option.fold_left (f (Option.fold_right (name_fold g) ona n)) acc po
+ Option.fold_left
+ (f (Option.fold_right (down_located (name_fold g)) ona n)) acc po
| CFix (loc,_,l) ->
let n' = List.fold_right (fun ((_,id),_,_,_,_) -> g id) l n in
List.fold_right (fun (_,(_,o),lb,t,c) acc ->
@@ -961,21 +1126,29 @@ let coerce_to_name = function
(* Interpret the index of a recursion order annotation *)
-let index_of_annot bl na =
+let split_at_annot bl na =
let names = List.map snd (names_of_local_assums bl) in
match na with
| None ->
- if names = [] then error "A fixpoint needs at least one parameter."
- else None
+ if names = [] then error "A fixpoint needs at least one parameter."
+ else [], bl
| Some (loc, id) ->
- try Some (list_index0 (Name id) names)
- with Not_found ->
- user_err_loc(loc,"",
- str "No parameter named " ++ Nameops.pr_id id ++ str".")
+ let rec aux acc = function
+ | LocalRawAssum (bls, k, t) as x :: rest ->
+ let l, r = list_split_when (fun (loc, na) -> na = Name id) bls in
+ if r = [] then aux (x :: acc) rest
+ else
+ (List.rev (if l = [] then acc else LocalRawAssum (l, k, t) :: acc),
+ LocalRawAssum (r, k, t) :: rest)
+ | LocalRawDef _ as x :: rest -> aux (x :: acc) rest
+ | [] ->
+ user_err_loc(loc,"",
+ str "No parameter named " ++ Nameops.pr_id id ++ str".")
+ in aux [] bl
(* Used in correctness and interface *)
-let map_binder g e nal = List.fold_right (fun (_,na) -> name_fold g na) nal e
+let map_binder g e nal = List.fold_right (down_located (name_fold g)) nal e
let map_binders f g e bl =
(* TODO: avoid variable capture in [t] by some [na] in [List.tl nal] *)
@@ -1005,8 +1178,10 @@ let map_constr_expr_with_binders g f e = function
| CLetIn (loc,na,a,b) -> CLetIn (loc,na,f e a,f (name_fold g (snd na) e) b)
| CCast (loc,a,CastConv (k,b)) -> CCast (loc,f e a,CastConv(k, f e b))
| CCast (loc,a,CastCoerce) -> CCast (loc,f e a,CastCoerce)
- | CNotation (loc,n,(l,ll)) ->
- CNotation (loc,n,(List.map (f e) l,List.map (List.map (f e)) ll))
+ | CNotation (loc,n,(l,ll,bll)) ->
+ (* This is an approximation because we don't know what binds what *)
+ CNotation (loc,n,(List.map (f e) l,List.map (List.map (f e)) ll,
+ List.map (fun bl -> snd (map_local_binders f g e bl)) bll))
| CGeneralization (loc,b,a,c) -> CGeneralization (loc,b,a,f e c)
| CDelimiters (loc,s,a) -> CDelimiters (loc,s,f e a)
| CHole _ | CEvar _ | CPatVar _ | CSort _
@@ -1019,11 +1194,11 @@ let map_constr_expr_with_binders g f e = function
let po = Option.map (f (List.fold_right g ids e)) rtnpo in
CCases (loc, sty, po, List.map (fun (tm,x) -> (f e tm,x)) a,bl)
| CLetTuple (loc,nal,(ona,po),b,c) ->
- let e' = List.fold_right (name_fold g) nal e in
- let e'' = Option.fold_right (name_fold g) ona e in
+ let e' = List.fold_right (down_located (name_fold g)) nal e in
+ let e'' = Option.fold_right (down_located (name_fold g)) ona e in
CLetTuple (loc,nal,(ona,Option.map (f e'') po),f e b,f e' c)
| CIf (loc,c,(ona,po),b1,b2) ->
- let e' = Option.fold_right (name_fold g) ona e in
+ let e' = Option.fold_right (down_located (name_fold g)) ona e in
CIf (loc,f e c,(ona,Option.map (f e') po),f e b1,f e b2)
| CFix (loc,id,dl) ->
CFix (loc,id,List.map (fun (id,n,bl,t,d) ->
@@ -1067,16 +1242,21 @@ type 'a module_signature =
| Check of 'a list (* ... <: T1 <: T2, possibly empty *)
(* Returns the ranges of locs of the notation that are not occupied by args *)
-(* and which are them occupied by proper symbols of the notation (or spaces) *)
+(* and which are then occupied by proper symbols of the notation (or spaces) *)
-let locs_of_notation f loc (args,argslist) ntn =
+let locs_of_notation loc locs ntn =
let (bl,el) = Util.unloc loc in
+ let locs = List.map Util.unloc locs in
let rec aux pos = function
| [] -> if pos = el then [] else [(pos,el-1)]
- | a::l ->
- let ba,ea = Util.unloc (f a) in
- if pos = ba then aux ea l else (pos,ba-1)::aux ea l
- in aux bl (args@List.flatten argslist)
+ | (ba,ea)::l ->if pos = ba then aux ea l else (pos,ba-1)::aux ea l
+ in aux bl (Sort.list (fun l1 l2 -> fst l1 < fst l2) locs)
+
+let ntn_loc loc (args,argslist,binderslist) =
+ locs_of_notation loc
+ (List.map constr_loc (args@List.flatten argslist)@
+ List.map local_binders_loc binderslist)
-let ntn_loc = locs_of_notation constr_loc
-let patntn_loc = locs_of_notation cases_pattern_expr_loc
+let patntn_loc loc (args,argslist) =
+ locs_of_notation loc
+ (List.map cases_pattern_expr_loc (args@List.flatten argslist))