aboutsummaryrefslogtreecommitdiffhomepage
path: root/interp
diff options
context:
space:
mode:
authorGravatar herbelin <herbelin@85f007b7-540e-0410-9357-904b9bb8a0f7>2004-09-09 13:27:03 +0000
committerGravatar herbelin <herbelin@85f007b7-540e-0410-9357-904b9bb8a0f7>2004-09-09 13:27:03 +0000
commitce7e64328553ac71f2630816cfb8baa930ea471e (patch)
treee0669306e771ebc7e5b9a7712212c39047e44d2e /interp
parent1795451a803ffa38d9b5cbf38d93cc7df687c11f (diff)
Ajout de or-pattern pour le match-with v8
git-svn-id: svn+ssh://scm.gforge.inria.fr/svn/coq/trunk@6088 85f007b7-540e-0410-9357-904b9bb8a0f7
Diffstat (limited to 'interp')
-rw-r--r--interp/constrintern.ml120
-rw-r--r--interp/topconstr.ml2
-rw-r--r--interp/topconstr.mli1
3 files changed, 80 insertions, 43 deletions
diff --git a/interp/constrintern.ml b/interp/constrintern.ml
index 326f8222e..03cd671af 100644
--- a/interp/constrintern.ml
+++ b/interp/constrintern.ml
@@ -355,6 +355,21 @@ let rec simple_adjust_scopes = function
(**********************************************************************)
(* Cases *)
+let product_of_cases_patterns ids idspl =
+ List.fold_right (fun (ids,pl) (ids',ptaill) ->
+ (ids@ids',
+ (* Cartesian prod of the or-pats for the nth arg and the tail args *)
+ List.flatten (
+ List.map (fun (subst,p) ->
+ List.map (fun (subst',ptail) -> (subst@subst',p::ptail)) ptaill) pl)))
+ idspl (ids,[[],[]])
+
+let simple_product_of_cases_patterns pl =
+ List.fold_right (fun pl ptaill ->
+ List.flatten (List.map (fun (subst,p) ->
+ List.map (fun (subst',ptail) -> (subst@subst',p::ptail)) ptaill) pl))
+ pl [[],[]]
+
(* Check linearity of pattern-matching *)
let rec has_duplicate = function
| [] -> None
@@ -391,6 +406,11 @@ let check_number_of_pattern loc n l =
let p = List.length l in
if n<>p then raise (InternalisationError (loc,BadPatternsNumber (n,p)))
+let check_or_pat_variables loc ids idsl =
+ if List.exists ((<>) ids) idsl then
+ user_err_loc (loc, "", str
+ "The components of this disjunctive pattern must bind the same variables")
+
(* Manage multiple aliases *)
(* [merge_aliases] returns the sets of all aliases encountered at this
@@ -412,13 +432,15 @@ let decode_patlist_value = function
| CPatCstr (_,_,l) -> l
| _ -> anomaly "Ill-formed list argument of notation"
-let rec subst_pat_iterator y t = function
+let rec subst_pat_iterator y t (subst,p) = match p with
| PatVar (_,id) as x ->
- if id = Name y then t else x
+ if id = Name y then t else [subst,x]
| PatCstr (loc,id,l,alias) ->
- PatCstr (loc,id,List.map (subst_pat_iterator y t) l,alias)
+ let l' = List.map (fun a -> (subst_pat_iterator y t ([],a))) l in
+ let pl = simple_product_of_cases_patterns l' in
+ List.map (fun (subst',pl) -> subst'@subst,PatCstr (loc,id,pl,alias)) pl
-let subst_cases_pattern loc aliases intern subst scopes a =
+let subst_cases_pattern loc (ids,asubst as aliases) intern subst scopes a =
let rec aux aliases subst = function
| AVar id ->
begin
@@ -428,7 +450,7 @@ let subst_cases_pattern loc aliases intern subst scopes a =
let (a,(scopt,subscopes)) = List.assoc id subst in
intern (subscopes@scopes) ([],[]) scopt a
with Not_found ->
- if id = ldots_var then [[],[]], PatVar (loc,Name id) else
+ if id = ldots_var then [],[[], PatVar (loc,Name id)] else
anomaly ("Unbound pattern notation variable: "^(string_of_id id))
(*
(* Happens for local notation joint with inductive/fixpoint defs *)
@@ -438,24 +460,28 @@ let subst_cases_pattern loc aliases intern subst scopes a =
*)
end
| ARef (ConstructRef c) ->
- [aliases], PatCstr (loc,c, [], alias_of aliases)
+ (ids,[asubst, PatCstr (loc,c, [], alias_of aliases)])
| AApp (ARef (ConstructRef (ind,_ as c)),args) ->
let nparams = (snd (Global.lookup_inductive ind)).Declarations.mind_nparams in
let _,args = list_chop nparams args in
- let (idsl,pl) = List.split (List.map (aux ([],[]) subst) args) in
- aliases::List.flatten idsl, PatCstr (loc,c,pl,alias_of aliases)
+ let idslpll = List.map (aux ([],[]) subst) args in
+ let ids',pll = product_of_cases_patterns ids idslpll in
+ let pl' = List.map (fun (subst,pl) ->
+ subst,PatCstr (loc,c,pl,alias_of aliases)) pll in
+ ids', pl'
| AList (x,_,iter,terminator,lassoc) ->
(try
(* All elements of the list are in scopes (scopt,subscopes) *)
let (a,(scopt,subscopes)) = List.assoc x subst in
- let idslt,termin = aux ([],[]) subst terminator in
+ let termin = aux ([],[]) subst terminator in
let l = decode_patlist_value a in
let idsl,v =
- List.fold_right (fun a (allidsl,t) ->
- let idsl,u = aux ([],[]) ((x,(a,(scopt,subscopes)))::subst) iter in
- idsl::allidsl, subst_pat_iterator ldots_var t u)
- (if lassoc then List.rev l else l) ([idslt],termin) in
- aliases::List.flatten idsl, v
+ List.fold_right (fun a (tids,t) ->
+ let uids,u = aux ([],[]) ((x,(a,(scopt,subscopes)))::subst) iter in
+ let pll = List.map (subst_pat_iterator ldots_var t) u in
+ tids@uids, List.flatten pll)
+ (if lassoc then List.rev l else l) termin in
+ ids@idsl, v
with Not_found ->
anomaly "Inconsistent substitution of recursive notation")
| t -> user_err_loc (loc,"",str "Invalid notation for pattern")
@@ -531,7 +557,7 @@ let mustbe_constructor loc ref =
with (Environ.NotEvaluableConst _ | Not_found) ->
raise (InternalisationError (loc,NotAConstructor ref))
-let rec intern_cases_pattern scopes aliases tmp_scope = function
+let rec intern_cases_pattern scopes (ids,subst as aliases) tmp_scope = function
| CPatAlias (loc, p, id) ->
let aliases' = merge_aliases aliases id in
intern_cases_pattern scopes aliases' tmp_scope p
@@ -539,15 +565,16 @@ let rec intern_cases_pattern scopes aliases tmp_scope = function
let c,pl0 = mustbe_constructor loc head in
let argscs =
simple_adjust_scopes (find_arguments_scope (ConstructRef c), pl) in
- let (idsl,pl') =
- List.split (List.map2 (intern_cases_pattern scopes ([],[])) argscs pl)
- in
- (aliases::(List.flatten idsl), PatCstr (loc,c,pl0@pl',alias_of aliases))
+ let idslpl = List.map2 (intern_cases_pattern scopes ([],[])) argscs pl in
+ let (ids',pll) = product_of_cases_patterns ids idslpl in
+ let pl' = List.map (fun (subst,pl) ->
+ (subst, PatCstr (loc,c,pl0@pl,alias_of aliases))) pll in
+ ids',pl'
| CPatNotation (loc,"- _",[CPatNumeral(_,Bignat.POS p)]) ->
let scopes = option_cons tmp_scope scopes in
- ([aliases],
- Symbols.interp_numeral_as_pattern loc (Bignat.NEG p)
- (alias_of aliases) scopes)
+ (ids,
+ [subst, Symbols.interp_numeral_as_pattern loc (Bignat.NEG p)
+ (alias_of aliases) scopes])
| CPatNotation (_,"( _ )",[a]) ->
intern_cases_pattern scopes aliases tmp_scope a
| CPatNotation (loc, ntn, args) ->
@@ -559,20 +586,27 @@ let rec intern_cases_pattern scopes aliases tmp_scope = function
subst_cases_pattern loc aliases intern_cases_pattern subst scopes c
| CPatNumeral (loc, n) ->
let scopes = option_cons tmp_scope scopes in
- ([aliases],
- Symbols.interp_numeral_as_pattern loc n (alias_of aliases) scopes)
+ (ids,[subst,
+ Symbols.interp_numeral_as_pattern loc n (alias_of aliases) scopes])
| CPatDelimiters (loc, key, e) ->
intern_cases_pattern (find_delimiters_scope loc key::scopes)
aliases None e
| CPatAtom (loc, Some head) ->
(match maybe_constructor head with
| ConstrPat (c,args) ->
- ([aliases], PatCstr (loc,c,args,alias_of aliases))
+ (ids,[subst, PatCstr (loc,c,args,alias_of aliases)])
| VarPat id ->
- let aliases = merge_aliases aliases id in
- ([aliases], PatVar (loc,alias_of aliases)))
+ let ids,subst = merge_aliases aliases id in
+ (ids,[subst, PatVar (loc,alias_of (ids,subst))]))
| CPatAtom (loc, None) ->
- ([aliases], PatVar (loc,alias_of aliases))
+ (ids,[subst, PatVar (loc,alias_of aliases)])
+ | CPatOr (loc, pl) ->
+ assert (pl <> []);
+ let pl' = List.map (intern_cases_pattern scopes aliases tmp_scope) pl in
+ let (idsl,pl') = List.split pl' in
+ let ids = List.hd idsl in
+ check_or_pat_variables loc ids (List.tl idsl);
+ (ids,List.flatten pl')
(**********************************************************************)
(* Fix and CoFix *)
@@ -852,8 +886,9 @@ let internalise sigma env allow_soapp lvar c =
(tm,ref ind)::inds,List.fold_left (push_name_env lvar) env nal)
tms ([],env) in
let rtnpo = option_app (intern_type env') rtnpo in
+ let eqns' = List.map (intern_eqn (List.length tms) env) eqns in
RCases (loc, (option_app (intern_type env) po, ref rtnpo), tms,
- List.map (intern_eqn (List.length tms) env) eqns)
+ List.flatten eqns')
| COrderedCase (loc, tag, po, c, cl) ->
let env = reset_tmp_scope env in
ROrderedCase (loc, tag, option_app (intern_type env) po,
@@ -911,20 +946,19 @@ let internalise sigma env allow_soapp lvar c =
(na,Some(intern env def),RHole(loc,BinderType na))::bl)
and intern_eqn n (ids,tmp_scope,scopes as env) (loc,lhs,rhs) =
- let (idsl_substl_list,pl) =
- List.split
- (List.map (intern_cases_pattern scopes ([],[]) None) lhs) in
- let idsl, substl = List.split (List.flatten idsl_substl_list) in
- let eqn_ids = List.flatten idsl in
- let subst = List.flatten substl in
- (* Linearity implies the order in ids is irrelevant *)
- check_linearity lhs eqn_ids;
- check_uppercase loc eqn_ids;
- check_number_of_pattern loc n pl;
- let rhs = replace_vars_constr_expr subst rhs in
- List.iter message_redundant_alias subst;
- let env_ids = List.fold_right Idset.add eqn_ids ids in
- (loc, eqn_ids,pl,intern (env_ids,tmp_scope,scopes) rhs)
+ let idsl_pll = List.map (intern_cases_pattern scopes ([],[]) None) lhs in
+
+ let eqn_ids,pll = product_of_cases_patterns [] idsl_pll in
+ (* Linearity implies the order in ids is irrelevant *)
+ check_linearity lhs eqn_ids;
+ check_uppercase loc eqn_ids;
+ check_number_of_pattern loc n (snd (List.hd pll));
+ let env_ids = List.fold_right Idset.add eqn_ids ids in
+ List.map (fun (subst,pl) ->
+ let rhs = replace_vars_constr_expr subst rhs in
+ List.iter message_redundant_alias subst;
+ let rhs' = intern (env_ids,tmp_scope,scopes) rhs in
+ (loc,eqn_ids,pl,rhs')) pll
and intern_case_item (vars,_,scopes as env) (tm,(na,t)) =
let tm' = intern env tm in
diff --git a/interp/topconstr.ml b/interp/topconstr.ml
index c92d790b1..d3b72ef78 100644
--- a/interp/topconstr.ml
+++ b/interp/topconstr.ml
@@ -481,6 +481,7 @@ type cases_pattern_expr =
| CPatAlias of loc * cases_pattern_expr * identifier
| CPatCstr of loc * reference * cases_pattern_expr list
| CPatAtom of loc * reference option
+ | CPatOr of loc * cases_pattern_expr list
| CPatNotation of loc * notation * cases_pattern_expr list
| CPatNumeral of loc * Bignat.bigint
| CPatDelimiters of loc * string * cases_pattern_expr
@@ -568,6 +569,7 @@ let cases_pattern_loc = function
| CPatAlias (loc,_,_) -> loc
| CPatCstr (loc,_,_) -> loc
| CPatAtom (loc,_) -> loc
+ | CPatOr (loc,_) -> loc
| CPatNotation (loc,_,_) -> loc
| CPatNumeral (loc,_) -> loc
| CPatDelimiters (loc,_,_) -> loc
diff --git a/interp/topconstr.mli b/interp/topconstr.mli
index fbdd57ca6..d5046b43b 100644
--- a/interp/topconstr.mli
+++ b/interp/topconstr.mli
@@ -74,6 +74,7 @@ type cases_pattern_expr =
| CPatAlias of loc * cases_pattern_expr * identifier
| CPatCstr of loc * reference * cases_pattern_expr list
| CPatAtom of loc * reference option
+ | CPatOr of loc * cases_pattern_expr list
| CPatNotation of loc * notation * cases_pattern_expr list
| CPatNumeral of loc * Bignat.bigint
| CPatDelimiters of loc * string * cases_pattern_expr