From 44e80ba4e4c863e0c38cc7cf6a65579640385436 Mon Sep 17 00:00:00 2001 From: letouzey Date: Thu, 8 Jul 2010 14:07:48 +0000 Subject: Extraction: more factorization of common match branches In addition to the "| _ -> cst" situation, now we can also reconstruct a "| e -> f e" final branch. For instance, this has a tremenduous effect on Compcert/backend/Selection.v. NB: The "fun" factorisation is almost more general than the "cst" situation, but not always. Think of A=>A|B=>A, 1st branch will be recognized as (fun x->x), not (fun x->A). We also add a fine detection of inductive types with phantom type variables, for which this optimisation is type-unsafe. git-svn-id: svn+ssh://scm.gforge.inria.fr/svn/coq/trunk@13267 85f007b7-540e-0410-9357-904b9bb8a0f7 --- plugins/extraction/extraction.ml | 5 +- plugins/extraction/haskell.ml | 28 ++++-- plugins/extraction/miniml.mli | 13 ++- plugins/extraction/mlutil.ml | 183 +++++++++++++++++++++++++-------------- plugins/extraction/mlutil.mli | 4 + plugins/extraction/ocaml.ml | 28 ++++-- 6 files changed, 172 insertions(+), 89 deletions(-) (limited to 'plugins/extraction') diff --git a/plugins/extraction/extraction.ml b/plugins/extraction/extraction.ml index c0097d492..a4ec24ba2 100644 --- a/plugins/extraction/extraction.ml +++ b/plugins/extraction/extraction.ml @@ -360,7 +360,8 @@ and extract_ind env kn = (* kn is supposed to be in long form *) ip_logical = (not b); ip_sign = s; ip_vars = v; - ip_types = t }) + ip_types = t; + ip_optim_id_ok = None }) mib.mind_packets in @@ -787,7 +788,7 @@ and extract_case env mle ((kn,i) as ip,c,br) mlt = end else (* Standard case: we apply [extract_branch]. *) - MLcase ((mi.ind_info,[]), a, Array.init br_size extract_branch) + MLcase ((mi.ind_info,BranchNone), a, Array.init br_size extract_branch) (*s Extraction of a (co)-fixpoint. *) diff --git a/plugins/extraction/haskell.ml b/plugins/extraction/haskell.ml index d42840613..97f49d833 100644 --- a/plugins/extraction/haskell.ml +++ b/plugins/extraction/haskell.ml @@ -182,18 +182,28 @@ and pp_pat env factors pv = (fun () -> (spc ())) pr_id (List.rev ids))) ++ str " ->" ++ spc () ++ pp_expr par env' [] t) in + let factor_br, factor_l = try match factors with + | BranchFun (i::_ as l) -> check_function_branch pv.(i), l + | BranchCst (i::_ as l) -> ast_pop (check_constant_branch pv.(i)), l + | _ -> MLdummy, [] + with Impossible -> MLdummy, [] + in + let par = expr_needs_par factor_br in + let last = Array.length pv - 1 in prvecti - (fun i x -> if List.mem i factors then mt () else + (fun i x -> if List.mem i factor_l then mt () else (pp_one_pat pv.(i) ++ - if factors = [] && i = Array.length pv - 1 then mt () - else fnl () ++ str " ")) pv + if i = last && factor_l = [] then mt () else + fnl () ++ str " ")) pv ++ - match factors with - | [] -> mt () - | i::_ -> - let (_,ids,t) = pv.(i) in - let t = ast_lift (-List.length ids) t in - hov 2 (str "_ ->" ++ spc () ++ pp_expr (expr_needs_par t) env [] t) + if factor_l = [] then mt () else match factors with + | BranchFun _ -> + let ids, env' = push_vars [anonymous_name] env in + pr_id (List.hd ids) ++ str " ->" ++ spc () ++ + pp_expr par env' [] factor_br + | BranchCst _ -> + str "_ ->" ++ spc () ++ pp_expr par env [] factor_br + | BranchNone -> mt () (*s names of the functions ([ids]) are already pushed in [env], and passed here just for convenience. *) diff --git a/plugins/extraction/miniml.mli b/plugins/extraction/miniml.mli index a27a9cf03..410054624 100644 --- a/plugins/extraction/miniml.mli +++ b/plugins/extraction/miniml.mli @@ -57,8 +57,6 @@ type inductive_info = | Standard | Record of global_reference list -type case_info = int list (* list of branches to merge in a _ pattern *) - (* A [ml_ind_packet] is the miniml counterpart of a [one_inductive_body]. If the inductive is logical ([ip_logical = false]), then all other fields are unused. Otherwise, @@ -73,7 +71,9 @@ type ml_ind_packet = { ip_logical : bool; ip_sign : signature; ip_vars : identifier list; - ip_types : (ml_type list) array } + ip_types : (ml_type list) array; + mutable ip_optim_id_ok : bool option +} (* [ip_nparams] contains the number of parameters. *) @@ -96,6 +96,13 @@ type ml_ident = | Id of identifier | Tmp of identifier +(* list of branches to merge in a common pattern *) + +type case_info = + | BranchNone + | BranchFun of int list + | BranchCst of int list + type ml_branch = global_reference * ml_ident list * ml_ast and ml_ast = diff --git a/plugins/extraction/mlutil.ml b/plugins/extraction/mlutil.ml index 1cd226616..6a5a83b1d 100644 --- a/plugins/extraction/mlutil.ml +++ b/plugins/extraction/mlutil.ml @@ -236,6 +236,21 @@ let type_maxvar t = | _ -> n in parse 0 t +(*s What are the type variables occurring in [t]. *) + +let intset_union_map_list f l = + List.fold_left (fun s t -> Intset.union s (f t)) Intset.empty l + +let intset_union_map_array f a = + Array.fold_left (fun s t -> Intset.union s (f t)) Intset.empty a + +let rec type_listvar = function + | Tmeta {contents = Some t} -> type_listvar t + | Tvar i | Tvar' i -> Intset.singleton i + | Tarr (a,b) -> Intset.union (type_listvar a) (type_listvar b) + | Tglob (_,l) -> intset_union_map_list type_listvar l + | _ -> Intset.empty + (*s From [a -> b -> c] to [[a;b],c]. *) let rec type_decomp = function @@ -648,11 +663,12 @@ let rec ast_glob_subst s t = match t with (*S Auxiliary functions used in simplification of ML cases. *) -(*s [check_and_generalize (r0,l,c)] transforms any [MLcons(r0,l)] in [MLrel 1] - and raises [Impossible] if any variable in [l] occurs outside such a - [MLcons] *) +(*s [check_function_branch (r,l,c)] checks if branch [c] can be seen + as a function [f] applied to [MLcons(r,l)]. For that it transforms + any [MLcons(r,l)] in [MLrel 1] and raises [Impossible] if any + variable in [l] occurs outside such a [MLcons] *) -let check_and_generalize (r0,l,c) = +let check_function_branch (r,l,c) = let nargs = List.length l in let rec genrec n = function | MLrel i as c -> @@ -660,58 +676,85 @@ let check_and_generalize (r0,l,c) = if i'<1 then c else if i'>nargs then MLrel (i-nargs+1) else raise Impossible - | MLcons(_,r,args) when r=r0 && (test_eta_args_lift n nargs args) -> + | MLcons(_,r',args) when r=r' && (test_eta_args_lift n nargs args) -> MLrel (n+1) | a -> ast_map_lift genrec n a in genrec 0 c +(*s [check_constant_branch (r,l,c)] checks if branch [c] is independent + from the pattern [MLcons(r,l)]. For that is raises [Impossible] if any + variable in [l] occurs in [c], and otherwise returns [c] lifted to + appear like a function with one arg (for uniformity with the + branch-as-function optimization) *) + +let check_constant_branch (_,l,c) = + let n = List.length l in + if ast_occurs_itvl 1 n c then raise Impossible; + ast_lift (1-n) c + +(* The following structure allows to record which element occurred + at what position, and then finally return the most frequent + element and its positions. *) + +let census_add, census_max, census_clean = + let h = Hashtbl.create 13 in + let clear () = Hashtbl.clear h in + let add e i = + let l = try Hashtbl.find h e with Not_found -> [] in + Hashtbl.replace h e (i::l) + in + let max e0 = + let len = ref 0 and lst = ref [] and elm = ref e0 in + Hashtbl.iter + (fun e l -> + let n = List.length l in + if n > !len then begin len := n; lst := l; elm := e end) + h; + (!elm,!lst) + in + (add,max,clear) + +(* Given an abstraction function [abstr] (one of [check_*_branch]), + return the longest possible list of branches that have the + same abstraction, along with this abstraction. *) + +let factor_branches abstr br = + census_clean (); + for i = 0 to Array.length br - 1 do + try census_add (abstr br.(i)) i with Impossible -> () + done; + let br_factor, br_list = census_max MLdummy in + if br_list = [] then None + else if Array.length br >= 2 && List.length br_list < 2 then None + else Some (br_factor, br_list) + (*s [check_generalizable_case] checks if all branches can be seen as the same function [f] applied to the term matched. It is a generalized version - of the identity case optimization. *) + of both the identity case optimization and the constant case optimisation + ([f] can be a constant function) *) -(* CAVEAT: this optimization breaks typing in some special case. example: - [type 'x a = A]. Then [let f = function A -> A] has type ['x a -> 'y a], +(* The optimisation [factor_branches check_function_branch] breaks types + in some special case. Example: [type 'x a = A]. + Then [let f = function A -> A] has type ['x a -> 'y a], which is incompatible with the type of [let f x = x]. - By default, we brutally disable this optim except for the known types - in theories/Init/*.v *) - -let generalizable_mind m = - match mind_modpath m with - | MPfile dir -> is_dirpath_prefix_of (dirpath_of_string "Coq.Init") dir - | _ -> false - -let check_generalizable_case unsafe br = - if not unsafe then - (match br.(0) with - | ConstructRef ((kn,_),_), _, _ -> - if not (generalizable_mind kn) then raise Impossible - | _ -> assert false); - let f = check_and_generalize br.(0) in - for i = 1 to Array.length br - 1 do - if check_and_generalize br.(i) <> f then raise Impossible - done; f - -(*s Detecting similar branches of a match *) - -(* If several branches of a match are equal (and independent from their - patterns) we will print them using a _ pattern. If _all_ branches - are equal, we remove the match. -*) + We check first that there isn't such phantom variable in the inductive type + we're considering. *) -let common_branches br = - let tab = Hashtbl.create 13 in - for i = 0 to Array.length br - 1 do - let (r,ids,t) = br.(i) in - let n = List.length ids in - if not (ast_occurs_itvl 1 n t) then - let t = ast_lift (-n) t in - let l = try Hashtbl.find tab t with Not_found -> [] in - Hashtbl.replace tab t (i::l) - done; - let best = ref [] in - Hashtbl.iter - (fun _ l -> if List.length l > List.length !best then best := l) tab; - if Array.length br >= 2 && List.length !best < 2 then [] else !best +let check_optim_id br = + let (kn,i) = + match br.(0) with (ConstructRef (c,_),_,_) -> c | _ -> assert false + in + let ip = (snd (lookup_ind kn)).ind_packets.(i) in + match ip.ip_optim_id_ok with + | Some ok -> ok + | None -> + let tvars = + intset_union_map_array (intset_union_map_list type_listvar) + ip.ip_types + in + let ok = (Intset.cardinal tvars = List.length ip.ip_vars) in + ip.ip_optim_id_ok <- Some ok; + ok (*s If all branches are functions, try to permut the case and the functions. *) @@ -846,25 +889,33 @@ and simpl_case o i br e = if o.opt_case_iot && (is_iota_gen e) then (* Generalized iota-redex *) simpl o (iota_gen br e) else - try (* Does a term [f] exist such that each branch is [(f e)] ? *) - if not o.opt_case_idr then raise Impossible; - let f = check_generalizable_case o.opt_case_idg br in - simpl o (MLapp (MLlam (anonymous,f),[e])) - with Impossible -> - (* Detect common branches *) - let common_br = if not o.opt_case_cst then [] else common_branches br in - if List.length common_br = Array.length br then - let (_,ids,t) = br.(0) in ast_lift (-List.length ids) t - else - let new_i = (fst i, common_br) in - (* Swap the case and the lam if possible *) - if o.opt_case_fun - then - let ids,br = permut_case_fun br [] in - let n = List.length ids in - if n <> 0 then named_lams ids (MLcase (new_i,ast_lift n e, br)) - else MLcase (new_i,e,br) - else MLcase (new_i,e,br) + (* Swap the case and the lam if possible *) + let ids,br = if o.opt_case_fun then permut_case_fun br [] else [],br in + let n = List.length ids in + if n <> 0 then + simpl o (named_lams ids (MLcase (i,ast_lift n e, br))) + else + (* Does a term [f] exist such that many branches are [(f e)] ? *) + let opt1 = + if o.opt_case_idr && (o.opt_case_idg || check_optim_id br) then + factor_branches check_function_branch br + else None + in + (* Detect common constant branches. Often a particular case of + branch-as-function optim, but not always (e.g. A->A|B->A) *) + let opt2 = + if opt1 = None && o.opt_case_cst then + factor_branches check_constant_branch br + else opt1 + in + match opt2 with + | Some (f,ints) when List.length ints = Array.length br -> + (* if all branches have been factorized, we remove the match *) + simpl o (MLletin (Tmp anonymous_name, e, f)) + | Some (f,ints) -> + let ci = if ast_occurs 1 f then BranchFun ints else BranchCst ints + in MLcase ((fst i,ci), e, br) + | None -> MLcase (i, e, br) (*S Local prop elimination. *) (* We try to eliminate as many [prop] as possible inside an [ml_ast]. *) diff --git a/plugins/extraction/mlutil.mli b/plugins/extraction/mlutil.mli index 3466f22d3..440684cdd 100644 --- a/plugins/extraction/mlutil.mli +++ b/plugins/extraction/mlutil.mli @@ -112,6 +112,10 @@ val normalize : ml_ast -> ml_ast val optimize_fix : ml_ast -> ml_ast val inline : global_reference -> ml_ast -> bool +exception Impossible +val check_function_branch : ml_branch -> ml_ast +val check_constant_branch : ml_branch -> ml_ast + (* Classification of signatures *) type sign_kind = diff --git a/plugins/extraction/ocaml.ml b/plugins/extraction/ocaml.ml index 1ec08b3ec..3dd15f6e7 100644 --- a/plugins/extraction/ocaml.ml +++ b/plugins/extraction/ocaml.ml @@ -339,19 +339,29 @@ and pp_one_pat env i (r,ids,t) = expr and pp_pat env (info,factors) pv = + let factor_br, factor_l = try match factors with + | BranchFun (i::_ as l) -> check_function_branch pv.(i), l + | BranchCst (i::_ as l) -> ast_pop (check_constant_branch pv.(i)), l + | _ -> MLdummy, [] + with Impossible -> MLdummy, [] + in + let par = expr_needs_par factor_br in + let last = Array.length pv - 1 in prvecti - (fun i x -> if List.mem i factors then mt () else + (fun i x -> if List.mem i factor_l then mt () else let s1,s2 = pp_one_pat env info x in hov 2 (s1 ++ str " ->" ++ spc () ++ s2) ++ - (if factors = [] && i = Array.length pv-1 then mt () - else fnl () ++ str " | ")) pv + if i = last && factor_l = [] then mt () else + fnl () ++ str " | ") pv ++ - match factors with - | [] -> mt () - | i::_ -> - let (_,ids,t) = pv.(i) in - let t = ast_lift (-List.length ids) t in - hov 2 (str "_ ->" ++ spc () ++ pp_expr (expr_needs_par t) env [] t) + if factor_l = [] then mt () else match factors with + | BranchFun _ -> + let ids, env' = push_vars [anonymous_name] env in + hov 2 (pr_id (List.hd ids) ++ str " ->" ++ spc () ++ + pp_expr par env' [] factor_br) + | BranchCst _ -> + hov 2 (str "_ ->" ++ spc () ++ pp_expr par env [] factor_br) + | BranchNone -> mt () and pp_function env t = let bl,t' = collect_lams t in -- cgit v1.2.3