aboutsummaryrefslogtreecommitdiffhomepage
path: root/plugins
diff options
context:
space:
mode:
authorGravatar letouzey <letouzey@85f007b7-540e-0410-9357-904b9bb8a0f7>2010-07-08 14:07:48 +0000
committerGravatar letouzey <letouzey@85f007b7-540e-0410-9357-904b9bb8a0f7>2010-07-08 14:07:48 +0000
commit44e80ba4e4c863e0c38cc7cf6a65579640385436 (patch)
tree8240bcf69da4aa761fcb49516a5dc50038760d9e /plugins
parent9a69dd9a3f0b4ffe94d5ef4670858f7a7e99f405 (diff)
Extraction: more factorization of common match branches
In addition to the "| _ -> cst" situation, now we can also reconstruct a "| e -> f e" final branch. For instance, this has a tremenduous effect on Compcert/backend/Selection.v. NB: The "fun" factorisation is almost more general than the "cst" situation, but not always. Think of A=>A|B=>A, 1st branch will be recognized as (fun x->x), not (fun x->A). We also add a fine detection of inductive types with phantom type variables, for which this optimisation is type-unsafe. git-svn-id: svn+ssh://scm.gforge.inria.fr/svn/coq/trunk@13267 85f007b7-540e-0410-9357-904b9bb8a0f7
Diffstat (limited to 'plugins')
-rw-r--r--plugins/extraction/extraction.ml5
-rw-r--r--plugins/extraction/haskell.ml28
-rw-r--r--plugins/extraction/miniml.mli13
-rw-r--r--plugins/extraction/mlutil.ml183
-rw-r--r--plugins/extraction/mlutil.mli4
-rw-r--r--plugins/extraction/ocaml.ml28
6 files changed, 172 insertions, 89 deletions
diff --git a/plugins/extraction/extraction.ml b/plugins/extraction/extraction.ml
index c0097d492..a4ec24ba2 100644
--- a/plugins/extraction/extraction.ml
+++ b/plugins/extraction/extraction.ml
@@ -360,7 +360,8 @@ and extract_ind env kn = (* kn is supposed to be in long form *)
ip_logical = (not b);
ip_sign = s;
ip_vars = v;
- ip_types = t })
+ ip_types = t;
+ ip_optim_id_ok = None })
mib.mind_packets
in
@@ -787,7 +788,7 @@ and extract_case env mle ((kn,i) as ip,c,br) mlt =
end
else
(* Standard case: we apply [extract_branch]. *)
- MLcase ((mi.ind_info,[]), a, Array.init br_size extract_branch)
+ MLcase ((mi.ind_info,BranchNone), a, Array.init br_size extract_branch)
(*s Extraction of a (co)-fixpoint. *)
diff --git a/plugins/extraction/haskell.ml b/plugins/extraction/haskell.ml
index d42840613..97f49d833 100644
--- a/plugins/extraction/haskell.ml
+++ b/plugins/extraction/haskell.ml
@@ -182,18 +182,28 @@ and pp_pat env factors pv =
(fun () -> (spc ())) pr_id (List.rev ids))) ++
str " ->" ++ spc () ++ pp_expr par env' [] t)
in
+ let factor_br, factor_l = try match factors with
+ | BranchFun (i::_ as l) -> check_function_branch pv.(i), l
+ | BranchCst (i::_ as l) -> ast_pop (check_constant_branch pv.(i)), l
+ | _ -> MLdummy, []
+ with Impossible -> MLdummy, []
+ in
+ let par = expr_needs_par factor_br in
+ let last = Array.length pv - 1 in
prvecti
- (fun i x -> if List.mem i factors then mt () else
+ (fun i x -> if List.mem i factor_l then mt () else
(pp_one_pat pv.(i) ++
- if factors = [] && i = Array.length pv - 1 then mt ()
- else fnl () ++ str " ")) pv
+ if i = last && factor_l = [] then mt () else
+ fnl () ++ str " ")) pv
++
- match factors with
- | [] -> mt ()
- | i::_ ->
- let (_,ids,t) = pv.(i) in
- let t = ast_lift (-List.length ids) t in
- hov 2 (str "_ ->" ++ spc () ++ pp_expr (expr_needs_par t) env [] t)
+ if factor_l = [] then mt () else match factors with
+ | BranchFun _ ->
+ let ids, env' = push_vars [anonymous_name] env in
+ pr_id (List.hd ids) ++ str " ->" ++ spc () ++
+ pp_expr par env' [] factor_br
+ | BranchCst _ ->
+ str "_ ->" ++ spc () ++ pp_expr par env [] factor_br
+ | BranchNone -> mt ()
(*s names of the functions ([ids]) are already pushed in [env],
and passed here just for convenience. *)
diff --git a/plugins/extraction/miniml.mli b/plugins/extraction/miniml.mli
index a27a9cf03..410054624 100644
--- a/plugins/extraction/miniml.mli
+++ b/plugins/extraction/miniml.mli
@@ -57,8 +57,6 @@ type inductive_info =
| Standard
| Record of global_reference list
-type case_info = int list (* list of branches to merge in a _ pattern *)
-
(* A [ml_ind_packet] is the miniml counterpart of a [one_inductive_body].
If the inductive is logical ([ip_logical = false]), then all other fields
are unused. Otherwise,
@@ -73,7 +71,9 @@ type ml_ind_packet = {
ip_logical : bool;
ip_sign : signature;
ip_vars : identifier list;
- ip_types : (ml_type list) array }
+ ip_types : (ml_type list) array;
+ mutable ip_optim_id_ok : bool option
+}
(* [ip_nparams] contains the number of parameters. *)
@@ -96,6 +96,13 @@ type ml_ident =
| Id of identifier
| Tmp of identifier
+(* list of branches to merge in a common pattern *)
+
+type case_info =
+ | BranchNone
+ | BranchFun of int list
+ | BranchCst of int list
+
type ml_branch = global_reference * ml_ident list * ml_ast
and ml_ast =
diff --git a/plugins/extraction/mlutil.ml b/plugins/extraction/mlutil.ml
index 1cd226616..6a5a83b1d 100644
--- a/plugins/extraction/mlutil.ml
+++ b/plugins/extraction/mlutil.ml
@@ -236,6 +236,21 @@ let type_maxvar t =
| _ -> n
in parse 0 t
+(*s What are the type variables occurring in [t]. *)
+
+let intset_union_map_list f l =
+ List.fold_left (fun s t -> Intset.union s (f t)) Intset.empty l
+
+let intset_union_map_array f a =
+ Array.fold_left (fun s t -> Intset.union s (f t)) Intset.empty a
+
+let rec type_listvar = function
+ | Tmeta {contents = Some t} -> type_listvar t
+ | Tvar i | Tvar' i -> Intset.singleton i
+ | Tarr (a,b) -> Intset.union (type_listvar a) (type_listvar b)
+ | Tglob (_,l) -> intset_union_map_list type_listvar l
+ | _ -> Intset.empty
+
(*s From [a -> b -> c] to [[a;b],c]. *)
let rec type_decomp = function
@@ -648,11 +663,12 @@ let rec ast_glob_subst s t = match t with
(*S Auxiliary functions used in simplification of ML cases. *)
-(*s [check_and_generalize (r0,l,c)] transforms any [MLcons(r0,l)] in [MLrel 1]
- and raises [Impossible] if any variable in [l] occurs outside such a
- [MLcons] *)
+(*s [check_function_branch (r,l,c)] checks if branch [c] can be seen
+ as a function [f] applied to [MLcons(r,l)]. For that it transforms
+ any [MLcons(r,l)] in [MLrel 1] and raises [Impossible] if any
+ variable in [l] occurs outside such a [MLcons] *)
-let check_and_generalize (r0,l,c) =
+let check_function_branch (r,l,c) =
let nargs = List.length l in
let rec genrec n = function
| MLrel i as c ->
@@ -660,58 +676,85 @@ let check_and_generalize (r0,l,c) =
if i'<1 then c
else if i'>nargs then MLrel (i-nargs+1)
else raise Impossible
- | MLcons(_,r,args) when r=r0 && (test_eta_args_lift n nargs args) ->
+ | MLcons(_,r',args) when r=r' && (test_eta_args_lift n nargs args) ->
MLrel (n+1)
| a -> ast_map_lift genrec n a
in genrec 0 c
+(*s [check_constant_branch (r,l,c)] checks if branch [c] is independent
+ from the pattern [MLcons(r,l)]. For that is raises [Impossible] if any
+ variable in [l] occurs in [c], and otherwise returns [c] lifted to
+ appear like a function with one arg (for uniformity with the
+ branch-as-function optimization) *)
+
+let check_constant_branch (_,l,c) =
+ let n = List.length l in
+ if ast_occurs_itvl 1 n c then raise Impossible;
+ ast_lift (1-n) c
+
+(* The following structure allows to record which element occurred
+ at what position, and then finally return the most frequent
+ element and its positions. *)
+
+let census_add, census_max, census_clean =
+ let h = Hashtbl.create 13 in
+ let clear () = Hashtbl.clear h in
+ let add e i =
+ let l = try Hashtbl.find h e with Not_found -> [] in
+ Hashtbl.replace h e (i::l)
+ in
+ let max e0 =
+ let len = ref 0 and lst = ref [] and elm = ref e0 in
+ Hashtbl.iter
+ (fun e l ->
+ let n = List.length l in
+ if n > !len then begin len := n; lst := l; elm := e end)
+ h;
+ (!elm,!lst)
+ in
+ (add,max,clear)
+
+(* Given an abstraction function [abstr] (one of [check_*_branch]),
+ return the longest possible list of branches that have the
+ same abstraction, along with this abstraction. *)
+
+let factor_branches abstr br =
+ census_clean ();
+ for i = 0 to Array.length br - 1 do
+ try census_add (abstr br.(i)) i with Impossible -> ()
+ done;
+ let br_factor, br_list = census_max MLdummy in
+ if br_list = [] then None
+ else if Array.length br >= 2 && List.length br_list < 2 then None
+ else Some (br_factor, br_list)
+
(*s [check_generalizable_case] checks if all branches can be seen as the
same function [f] applied to the term matched. It is a generalized version
- of the identity case optimization. *)
+ of both the identity case optimization and the constant case optimisation
+ ([f] can be a constant function) *)
-(* CAVEAT: this optimization breaks typing in some special case. example:
- [type 'x a = A]. Then [let f = function A -> A] has type ['x a -> 'y a],
+(* The optimisation [factor_branches check_function_branch] breaks types
+ in some special case. Example: [type 'x a = A].
+ Then [let f = function A -> A] has type ['x a -> 'y a],
which is incompatible with the type of [let f x = x].
- By default, we brutally disable this optim except for the known types
- in theories/Init/*.v *)
-
-let generalizable_mind m =
- match mind_modpath m with
- | MPfile dir -> is_dirpath_prefix_of (dirpath_of_string "Coq.Init") dir
- | _ -> false
-
-let check_generalizable_case unsafe br =
- if not unsafe then
- (match br.(0) with
- | ConstructRef ((kn,_),_), _, _ ->
- if not (generalizable_mind kn) then raise Impossible
- | _ -> assert false);
- let f = check_and_generalize br.(0) in
- for i = 1 to Array.length br - 1 do
- if check_and_generalize br.(i) <> f then raise Impossible
- done; f
-
-(*s Detecting similar branches of a match *)
-
-(* If several branches of a match are equal (and independent from their
- patterns) we will print them using a _ pattern. If _all_ branches
- are equal, we remove the match.
-*)
+ We check first that there isn't such phantom variable in the inductive type
+ we're considering. *)
-let common_branches br =
- let tab = Hashtbl.create 13 in
- for i = 0 to Array.length br - 1 do
- let (r,ids,t) = br.(i) in
- let n = List.length ids in
- if not (ast_occurs_itvl 1 n t) then
- let t = ast_lift (-n) t in
- let l = try Hashtbl.find tab t with Not_found -> [] in
- Hashtbl.replace tab t (i::l)
- done;
- let best = ref [] in
- Hashtbl.iter
- (fun _ l -> if List.length l > List.length !best then best := l) tab;
- if Array.length br >= 2 && List.length !best < 2 then [] else !best
+let check_optim_id br =
+ let (kn,i) =
+ match br.(0) with (ConstructRef (c,_),_,_) -> c | _ -> assert false
+ in
+ let ip = (snd (lookup_ind kn)).ind_packets.(i) in
+ match ip.ip_optim_id_ok with
+ | Some ok -> ok
+ | None ->
+ let tvars =
+ intset_union_map_array (intset_union_map_list type_listvar)
+ ip.ip_types
+ in
+ let ok = (Intset.cardinal tvars = List.length ip.ip_vars) in
+ ip.ip_optim_id_ok <- Some ok;
+ ok
(*s If all branches are functions, try to permut the case and the functions. *)
@@ -846,25 +889,33 @@ and simpl_case o i br e =
if o.opt_case_iot && (is_iota_gen e) then (* Generalized iota-redex *)
simpl o (iota_gen br e)
else
- try (* Does a term [f] exist such that each branch is [(f e)] ? *)
- if not o.opt_case_idr then raise Impossible;
- let f = check_generalizable_case o.opt_case_idg br in
- simpl o (MLapp (MLlam (anonymous,f),[e]))
- with Impossible ->
- (* Detect common branches *)
- let common_br = if not o.opt_case_cst then [] else common_branches br in
- if List.length common_br = Array.length br then
- let (_,ids,t) = br.(0) in ast_lift (-List.length ids) t
- else
- let new_i = (fst i, common_br) in
- (* Swap the case and the lam if possible *)
- if o.opt_case_fun
- then
- let ids,br = permut_case_fun br [] in
- let n = List.length ids in
- if n <> 0 then named_lams ids (MLcase (new_i,ast_lift n e, br))
- else MLcase (new_i,e,br)
- else MLcase (new_i,e,br)
+ (* Swap the case and the lam if possible *)
+ let ids,br = if o.opt_case_fun then permut_case_fun br [] else [],br in
+ let n = List.length ids in
+ if n <> 0 then
+ simpl o (named_lams ids (MLcase (i,ast_lift n e, br)))
+ else
+ (* Does a term [f] exist such that many branches are [(f e)] ? *)
+ let opt1 =
+ if o.opt_case_idr && (o.opt_case_idg || check_optim_id br) then
+ factor_branches check_function_branch br
+ else None
+ in
+ (* Detect common constant branches. Often a particular case of
+ branch-as-function optim, but not always (e.g. A->A|B->A) *)
+ let opt2 =
+ if opt1 = None && o.opt_case_cst then
+ factor_branches check_constant_branch br
+ else opt1
+ in
+ match opt2 with
+ | Some (f,ints) when List.length ints = Array.length br ->
+ (* if all branches have been factorized, we remove the match *)
+ simpl o (MLletin (Tmp anonymous_name, e, f))
+ | Some (f,ints) ->
+ let ci = if ast_occurs 1 f then BranchFun ints else BranchCst ints
+ in MLcase ((fst i,ci), e, br)
+ | None -> MLcase (i, e, br)
(*S Local prop elimination. *)
(* We try to eliminate as many [prop] as possible inside an [ml_ast]. *)
diff --git a/plugins/extraction/mlutil.mli b/plugins/extraction/mlutil.mli
index 3466f22d3..440684cdd 100644
--- a/plugins/extraction/mlutil.mli
+++ b/plugins/extraction/mlutil.mli
@@ -112,6 +112,10 @@ val normalize : ml_ast -> ml_ast
val optimize_fix : ml_ast -> ml_ast
val inline : global_reference -> ml_ast -> bool
+exception Impossible
+val check_function_branch : ml_branch -> ml_ast
+val check_constant_branch : ml_branch -> ml_ast
+
(* Classification of signatures *)
type sign_kind =
diff --git a/plugins/extraction/ocaml.ml b/plugins/extraction/ocaml.ml
index 1ec08b3ec..3dd15f6e7 100644
--- a/plugins/extraction/ocaml.ml
+++ b/plugins/extraction/ocaml.ml
@@ -339,19 +339,29 @@ and pp_one_pat env i (r,ids,t) =
expr
and pp_pat env (info,factors) pv =
+ let factor_br, factor_l = try match factors with
+ | BranchFun (i::_ as l) -> check_function_branch pv.(i), l
+ | BranchCst (i::_ as l) -> ast_pop (check_constant_branch pv.(i)), l
+ | _ -> MLdummy, []
+ with Impossible -> MLdummy, []
+ in
+ let par = expr_needs_par factor_br in
+ let last = Array.length pv - 1 in
prvecti
- (fun i x -> if List.mem i factors then mt () else
+ (fun i x -> if List.mem i factor_l then mt () else
let s1,s2 = pp_one_pat env info x in
hov 2 (s1 ++ str " ->" ++ spc () ++ s2) ++
- (if factors = [] && i = Array.length pv-1 then mt ()
- else fnl () ++ str " | ")) pv
+ if i = last && factor_l = [] then mt () else
+ fnl () ++ str " | ") pv
++
- match factors with
- | [] -> mt ()
- | i::_ ->
- let (_,ids,t) = pv.(i) in
- let t = ast_lift (-List.length ids) t in
- hov 2 (str "_ ->" ++ spc () ++ pp_expr (expr_needs_par t) env [] t)
+ if factor_l = [] then mt () else match factors with
+ | BranchFun _ ->
+ let ids, env' = push_vars [anonymous_name] env in
+ hov 2 (pr_id (List.hd ids) ++ str " ->" ++ spc () ++
+ pp_expr par env' [] factor_br)
+ | BranchCst _ ->
+ hov 2 (str "_ ->" ++ spc () ++ pp_expr par env [] factor_br)
+ | BranchNone -> mt ()
and pp_function env t =
let bl,t' = collect_lams t in