diff options
author | 2004-09-09 13:27:03 +0000 | |
---|---|---|
committer | 2004-09-09 13:27:03 +0000 | |
commit | ce7e64328553ac71f2630816cfb8baa930ea471e (patch) | |
tree | e0669306e771ebc7e5b9a7712212c39047e44d2e /interp | |
parent | 1795451a803ffa38d9b5cbf38d93cc7df687c11f (diff) |
Ajout de or-pattern pour le match-with v8
git-svn-id: svn+ssh://scm.gforge.inria.fr/svn/coq/trunk@6088 85f007b7-540e-0410-9357-904b9bb8a0f7
Diffstat (limited to 'interp')
-rw-r--r-- | interp/constrintern.ml | 120 | ||||
-rw-r--r-- | interp/topconstr.ml | 2 | ||||
-rw-r--r-- | interp/topconstr.mli | 1 |
3 files changed, 80 insertions, 43 deletions
diff --git a/interp/constrintern.ml b/interp/constrintern.ml index 326f8222e..03cd671af 100644 --- a/interp/constrintern.ml +++ b/interp/constrintern.ml @@ -355,6 +355,21 @@ let rec simple_adjust_scopes = function (**********************************************************************) (* Cases *) +let product_of_cases_patterns ids idspl = + List.fold_right (fun (ids,pl) (ids',ptaill) -> + (ids@ids', + (* Cartesian prod of the or-pats for the nth arg and the tail args *) + List.flatten ( + List.map (fun (subst,p) -> + List.map (fun (subst',ptail) -> (subst@subst',p::ptail)) ptaill) pl))) + idspl (ids,[[],[]]) + +let simple_product_of_cases_patterns pl = + List.fold_right (fun pl ptaill -> + List.flatten (List.map (fun (subst,p) -> + List.map (fun (subst',ptail) -> (subst@subst',p::ptail)) ptaill) pl)) + pl [[],[]] + (* Check linearity of pattern-matching *) let rec has_duplicate = function | [] -> None @@ -391,6 +406,11 @@ let check_number_of_pattern loc n l = let p = List.length l in if n<>p then raise (InternalisationError (loc,BadPatternsNumber (n,p))) +let check_or_pat_variables loc ids idsl = + if List.exists ((<>) ids) idsl then + user_err_loc (loc, "", str + "The components of this disjunctive pattern must bind the same variables") + (* Manage multiple aliases *) (* [merge_aliases] returns the sets of all aliases encountered at this @@ -412,13 +432,15 @@ let decode_patlist_value = function | CPatCstr (_,_,l) -> l | _ -> anomaly "Ill-formed list argument of notation" -let rec subst_pat_iterator y t = function +let rec subst_pat_iterator y t (subst,p) = match p with | PatVar (_,id) as x -> - if id = Name y then t else x + if id = Name y then t else [subst,x] | PatCstr (loc,id,l,alias) -> - PatCstr (loc,id,List.map (subst_pat_iterator y t) l,alias) + let l' = List.map (fun a -> (subst_pat_iterator y t ([],a))) l in + let pl = simple_product_of_cases_patterns l' in + List.map (fun (subst',pl) -> subst'@subst,PatCstr (loc,id,pl,alias)) pl -let subst_cases_pattern loc aliases intern subst scopes a = +let subst_cases_pattern loc (ids,asubst as aliases) intern subst scopes a = let rec aux aliases subst = function | AVar id -> begin @@ -428,7 +450,7 @@ let subst_cases_pattern loc aliases intern subst scopes a = let (a,(scopt,subscopes)) = List.assoc id subst in intern (subscopes@scopes) ([],[]) scopt a with Not_found -> - if id = ldots_var then [[],[]], PatVar (loc,Name id) else + if id = ldots_var then [],[[], PatVar (loc,Name id)] else anomaly ("Unbound pattern notation variable: "^(string_of_id id)) (* (* Happens for local notation joint with inductive/fixpoint defs *) @@ -438,24 +460,28 @@ let subst_cases_pattern loc aliases intern subst scopes a = *) end | ARef (ConstructRef c) -> - [aliases], PatCstr (loc,c, [], alias_of aliases) + (ids,[asubst, PatCstr (loc,c, [], alias_of aliases)]) | AApp (ARef (ConstructRef (ind,_ as c)),args) -> let nparams = (snd (Global.lookup_inductive ind)).Declarations.mind_nparams in let _,args = list_chop nparams args in - let (idsl,pl) = List.split (List.map (aux ([],[]) subst) args) in - aliases::List.flatten idsl, PatCstr (loc,c,pl,alias_of aliases) + let idslpll = List.map (aux ([],[]) subst) args in + let ids',pll = product_of_cases_patterns ids idslpll in + let pl' = List.map (fun (subst,pl) -> + subst,PatCstr (loc,c,pl,alias_of aliases)) pll in + ids', pl' | AList (x,_,iter,terminator,lassoc) -> (try (* All elements of the list are in scopes (scopt,subscopes) *) let (a,(scopt,subscopes)) = List.assoc x subst in - let idslt,termin = aux ([],[]) subst terminator in + let termin = aux ([],[]) subst terminator in let l = decode_patlist_value a in let idsl,v = - List.fold_right (fun a (allidsl,t) -> - let idsl,u = aux ([],[]) ((x,(a,(scopt,subscopes)))::subst) iter in - idsl::allidsl, subst_pat_iterator ldots_var t u) - (if lassoc then List.rev l else l) ([idslt],termin) in - aliases::List.flatten idsl, v + List.fold_right (fun a (tids,t) -> + let uids,u = aux ([],[]) ((x,(a,(scopt,subscopes)))::subst) iter in + let pll = List.map (subst_pat_iterator ldots_var t) u in + tids@uids, List.flatten pll) + (if lassoc then List.rev l else l) termin in + ids@idsl, v with Not_found -> anomaly "Inconsistent substitution of recursive notation") | t -> user_err_loc (loc,"",str "Invalid notation for pattern") @@ -531,7 +557,7 @@ let mustbe_constructor loc ref = with (Environ.NotEvaluableConst _ | Not_found) -> raise (InternalisationError (loc,NotAConstructor ref)) -let rec intern_cases_pattern scopes aliases tmp_scope = function +let rec intern_cases_pattern scopes (ids,subst as aliases) tmp_scope = function | CPatAlias (loc, p, id) -> let aliases' = merge_aliases aliases id in intern_cases_pattern scopes aliases' tmp_scope p @@ -539,15 +565,16 @@ let rec intern_cases_pattern scopes aliases tmp_scope = function let c,pl0 = mustbe_constructor loc head in let argscs = simple_adjust_scopes (find_arguments_scope (ConstructRef c), pl) in - let (idsl,pl') = - List.split (List.map2 (intern_cases_pattern scopes ([],[])) argscs pl) - in - (aliases::(List.flatten idsl), PatCstr (loc,c,pl0@pl',alias_of aliases)) + let idslpl = List.map2 (intern_cases_pattern scopes ([],[])) argscs pl in + let (ids',pll) = product_of_cases_patterns ids idslpl in + let pl' = List.map (fun (subst,pl) -> + (subst, PatCstr (loc,c,pl0@pl,alias_of aliases))) pll in + ids',pl' | CPatNotation (loc,"- _",[CPatNumeral(_,Bignat.POS p)]) -> let scopes = option_cons tmp_scope scopes in - ([aliases], - Symbols.interp_numeral_as_pattern loc (Bignat.NEG p) - (alias_of aliases) scopes) + (ids, + [subst, Symbols.interp_numeral_as_pattern loc (Bignat.NEG p) + (alias_of aliases) scopes]) | CPatNotation (_,"( _ )",[a]) -> intern_cases_pattern scopes aliases tmp_scope a | CPatNotation (loc, ntn, args) -> @@ -559,20 +586,27 @@ let rec intern_cases_pattern scopes aliases tmp_scope = function subst_cases_pattern loc aliases intern_cases_pattern subst scopes c | CPatNumeral (loc, n) -> let scopes = option_cons tmp_scope scopes in - ([aliases], - Symbols.interp_numeral_as_pattern loc n (alias_of aliases) scopes) + (ids,[subst, + Symbols.interp_numeral_as_pattern loc n (alias_of aliases) scopes]) | CPatDelimiters (loc, key, e) -> intern_cases_pattern (find_delimiters_scope loc key::scopes) aliases None e | CPatAtom (loc, Some head) -> (match maybe_constructor head with | ConstrPat (c,args) -> - ([aliases], PatCstr (loc,c,args,alias_of aliases)) + (ids,[subst, PatCstr (loc,c,args,alias_of aliases)]) | VarPat id -> - let aliases = merge_aliases aliases id in - ([aliases], PatVar (loc,alias_of aliases))) + let ids,subst = merge_aliases aliases id in + (ids,[subst, PatVar (loc,alias_of (ids,subst))])) | CPatAtom (loc, None) -> - ([aliases], PatVar (loc,alias_of aliases)) + (ids,[subst, PatVar (loc,alias_of aliases)]) + | CPatOr (loc, pl) -> + assert (pl <> []); + let pl' = List.map (intern_cases_pattern scopes aliases tmp_scope) pl in + let (idsl,pl') = List.split pl' in + let ids = List.hd idsl in + check_or_pat_variables loc ids (List.tl idsl); + (ids,List.flatten pl') (**********************************************************************) (* Fix and CoFix *) @@ -852,8 +886,9 @@ let internalise sigma env allow_soapp lvar c = (tm,ref ind)::inds,List.fold_left (push_name_env lvar) env nal) tms ([],env) in let rtnpo = option_app (intern_type env') rtnpo in + let eqns' = List.map (intern_eqn (List.length tms) env) eqns in RCases (loc, (option_app (intern_type env) po, ref rtnpo), tms, - List.map (intern_eqn (List.length tms) env) eqns) + List.flatten eqns') | COrderedCase (loc, tag, po, c, cl) -> let env = reset_tmp_scope env in ROrderedCase (loc, tag, option_app (intern_type env) po, @@ -911,20 +946,19 @@ let internalise sigma env allow_soapp lvar c = (na,Some(intern env def),RHole(loc,BinderType na))::bl) and intern_eqn n (ids,tmp_scope,scopes as env) (loc,lhs,rhs) = - let (idsl_substl_list,pl) = - List.split - (List.map (intern_cases_pattern scopes ([],[]) None) lhs) in - let idsl, substl = List.split (List.flatten idsl_substl_list) in - let eqn_ids = List.flatten idsl in - let subst = List.flatten substl in - (* Linearity implies the order in ids is irrelevant *) - check_linearity lhs eqn_ids; - check_uppercase loc eqn_ids; - check_number_of_pattern loc n pl; - let rhs = replace_vars_constr_expr subst rhs in - List.iter message_redundant_alias subst; - let env_ids = List.fold_right Idset.add eqn_ids ids in - (loc, eqn_ids,pl,intern (env_ids,tmp_scope,scopes) rhs) + let idsl_pll = List.map (intern_cases_pattern scopes ([],[]) None) lhs in + + let eqn_ids,pll = product_of_cases_patterns [] idsl_pll in + (* Linearity implies the order in ids is irrelevant *) + check_linearity lhs eqn_ids; + check_uppercase loc eqn_ids; + check_number_of_pattern loc n (snd (List.hd pll)); + let env_ids = List.fold_right Idset.add eqn_ids ids in + List.map (fun (subst,pl) -> + let rhs = replace_vars_constr_expr subst rhs in + List.iter message_redundant_alias subst; + let rhs' = intern (env_ids,tmp_scope,scopes) rhs in + (loc,eqn_ids,pl,rhs')) pll and intern_case_item (vars,_,scopes as env) (tm,(na,t)) = let tm' = intern env tm in diff --git a/interp/topconstr.ml b/interp/topconstr.ml index c92d790b1..d3b72ef78 100644 --- a/interp/topconstr.ml +++ b/interp/topconstr.ml @@ -481,6 +481,7 @@ type cases_pattern_expr = | CPatAlias of loc * cases_pattern_expr * identifier | CPatCstr of loc * reference * cases_pattern_expr list | CPatAtom of loc * reference option + | CPatOr of loc * cases_pattern_expr list | CPatNotation of loc * notation * cases_pattern_expr list | CPatNumeral of loc * Bignat.bigint | CPatDelimiters of loc * string * cases_pattern_expr @@ -568,6 +569,7 @@ let cases_pattern_loc = function | CPatAlias (loc,_,_) -> loc | CPatCstr (loc,_,_) -> loc | CPatAtom (loc,_) -> loc + | CPatOr (loc,_) -> loc | CPatNotation (loc,_,_) -> loc | CPatNumeral (loc,_) -> loc | CPatDelimiters (loc,_,_) -> loc diff --git a/interp/topconstr.mli b/interp/topconstr.mli index fbdd57ca6..d5046b43b 100644 --- a/interp/topconstr.mli +++ b/interp/topconstr.mli @@ -74,6 +74,7 @@ type cases_pattern_expr = | CPatAlias of loc * cases_pattern_expr * identifier | CPatCstr of loc * reference * cases_pattern_expr list | CPatAtom of loc * reference option + | CPatOr of loc * cases_pattern_expr list | CPatNotation of loc * notation * cases_pattern_expr list | CPatNumeral of loc * Bignat.bigint | CPatDelimiters of loc * string * cases_pattern_expr |