From 8917ab003a9b7f2abf8e399b5e7ad013b31a2e0e Mon Sep 17 00:00:00 2001 From: Stephane Glondu Date: Thu, 20 Sep 2012 09:41:14 +0200 Subject: Imported Upstream version 0.3 --- matcher.ml | 1293 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1293 insertions(+) create mode 100644 matcher.ml (limited to 'matcher.ml') diff --git a/matcher.ml b/matcher.ml new file mode 100644 index 0000000..55a64b7 --- /dev/null +++ b/matcher.ml @@ -0,0 +1,1293 @@ +(***************************************************************************) +(* This is part of aac_tactics, it is distributed under the terms of the *) +(* GNU Lesser General Public License version 3 *) +(* (see file LICENSE for more details) *) +(* *) +(* Copyright 2009-2010: Thomas Braibant, Damien Pous. *) +(***************************************************************************) + +(** This module defines our matching functions, modulo associativity + and commutativity (AAC). + + The basic idea is to find a substitution [env] such that the + pattern [p] instantiated by [env] is equal to [t] modulo AAC. + + We proceed by structural decomposition of the pattern, and try all + possible non-deterministic split of the subject when needed. The + function {!matcher} is limited to top-level matching, that is, the + subject must make a perfect match against the pattern ([x+x] do + not match [a+a+b] ). We use a search monad {!Search} to perform + non-deterministic splits in an almost transparent way. We also + provide a function {!subterm} for finding a match that is a + subterm modulo AAC of the subject. Therefore, we are able to solve + the aforementioned case [x+x] against [a+b+a]. + + This file is structured as follows. First, we define the search + monad. Then,we define the two representations of terms (one + representing the AST, and one in normal form ), and environments + from variables to terms. Then, we use these parts to solve + matching problem. Finally, we wrap this function {!matcher} into + {!subterm} +*) + + +module Control = +struct + let debug = false + let time = false + let printing = false +end + +module Debug = Helper.Debug (Control) +open Debug + +module Search = Search_monad (* a handle *) + +type symbol = int +type var = int +type units = (symbol * symbol) list (* from AC/A symbols to the unit *) +type ext_units = + { + unit_for : units; + is_ac : (symbol * bool) list + } + +let print_units units= + List.iter (fun (op,unit) -> Printf.printf "%i %i\t" op unit) units; + Printf.printf "\n%!" + +exception NoUnit + +let get_unit units op = + try List.assoc op units + with + | Not_found -> raise NoUnit + +let is_unit units op unit = List.mem (op,unit) units + + +open Search + +type 'a mset = ('a * int) list +let linear p = + let rec ncons t l = function + | 0 -> l + | n -> t::ncons t l (n-1) + in + let rec aux = function + [ ] -> [] + | (t,n)::q -> let q = aux q in + ncons t q n + in aux p + + + +(** The module {!Terms} defines two different types for expressions. + + - a public type {!Terms.t} that represent abstract syntax trees of + expressions with binary associative (and commutative) operators + + - a private type {!Terms.nf_term} that represent an equivalence + class for terms that are equal modulo AAC. The constructions + functions on this type ensure the property that the term is in + normal form (that is, no sum can appear as a subterm of the same + sum, no trailing units, etc...). + +*) + +module Terms : sig + + (** {1 Abstract syntax tree of terms} + + Terms represented using this datatype are representation of the + AST of an expression. *) + + type t = + Dot of (symbol * t * t) + | Plus of (symbol * t * t) + | Sym of (symbol * t array) + | Var of var + | Unit of symbol + + val equal_aac : units -> t -> t -> bool + val size: t -> int + + (** {1 Terms in normal form} + + A term in normal form is the canonical representative of the + equivalence class of all the terms that are equal modulo + Associativity and Commutativity. Outside the {!Matcher} module, + one does not need to access the actual representation of this + type. *) + + + type nf_term = private + | TAC of symbol * nf_term mset + | TA of symbol * nf_term list + | TSym of symbol * nf_term list + | TUnit of symbol + | TVar of var + + (** {2 Fresh variables: we provide some functions to pick some fresh variables with respect to a term} *) + val fresh_var_term : t -> int + val fresh_var_nfterm : nf_term -> int + + + (** {2 Constructors: we ensure that the terms are always + normalised + + braibant - Fri 27 Aug 2010, 15:11 + Moreover, we assure that we will not build empty sums or empty + products, leaving this task to the caller function. + } + *) + val mk_TAC : units -> symbol -> nf_term mset -> nf_term + val mk_TA : units -> symbol -> nf_term list -> nf_term + val mk_TSym : symbol -> nf_term list -> nf_term + val mk_TVar : var -> nf_term + val mk_TUnit : symbol -> nf_term + + (** {2 Comparisons} *) + val nf_term_compare : nf_term -> nf_term -> int + val nf_equal : nf_term -> nf_term -> bool + + (** {2 Printing function} *) + val sprint_nf_term : nf_term -> string + + (** {2 Conversion functions} *) + val term_of_t : units -> t -> nf_term + val t_of_term : nf_term -> t (* we could return the units here *) +end + = struct + + type t = + Dot of (symbol * t * t) + | Plus of (symbol * t * t) + | Sym of (symbol * t array) + | Var of var + | Unit of symbol + + let rec size = function + | Dot (_,x,y) + | Plus (_,x,y) -> size x+ size y + 1 + | Sym (_,v)-> Array.fold_left (fun acc x -> size x + acc) 1 v + | _ -> 1 + + + type nf_term = + | TAC of symbol * nf_term mset + | TA of symbol * nf_term list + | TSym of symbol * nf_term list + | TUnit of symbol + | TVar of var + + + (** {2 Picking fresh variables} *) + + (** [fresh_var_term] picks a fresh variable with respect to a term *) + let fresh_var_term t = + let rec aux = function + | Dot (_,t1,t2) + | Plus (_,t1,t2) -> max (aux t1) (aux t2) + | Sym (_,v) -> Array.fold_left (fun acc x -> max acc (aux x)) 0 v + | Var v -> assert (v >= 0); v + | Unit _ -> 0 + in + aux t + + (** [fresh_var_nfterm] picks a fresh_variable with respect to a term *) + let fresh_var_nfterm t = + let rec aux = function + | TAC (_,l) -> List.fold_left (fun acc (x,_) -> max acc (aux x)) 0 l + | TA (_,l) + | TSym (_,l) -> List.fold_left (fun acc x -> max acc (aux x)) 0 l + | TVar v -> assert (v >= 0); v + | TUnit _ -> 0 + in + aux t + + (** {2 Constructors: we ensure that the terms are always + normalised} *) + + (** {3 Pre constructors : These constructors ensure that sums and + products are not degenerated (no trailing units)} *) + let mk_TAC' units (s : symbol) l = match l with + | [] -> TUnit (get_unit units s ) + | [_,0] -> assert false + | [t,1] -> t + | _ -> TAC (s,l) + let mk_TA' units (s : symbol) l = match l with + | [] -> TUnit (get_unit units s ) + | [t] -> t + | _ -> TA (s,l) + + + (** {2 Comparison} *) + + let nf_term_compare = Pervasives.compare + let nf_equal a b = + a = b + + (** [merge_ac comp l1 l2] merges two lists of terms with coefficients + into one. Terms that are equal modulo the comparison function + [comp] will see their arities added. *) + + (* This function could be improved by the use of sorted msets *) + let merge_ac (compare : 'a -> 'a -> int) (l : 'a mset) (l' : 'a mset) : 'a mset = + let rec aux l l'= + match l,l' with + | [], _ -> l' + | _, [] -> l + | (t,tar)::q, (t',tar')::q' -> + begin match compare t t' with + | 0 -> ( t,tar+tar'):: aux q q' + | -1 -> (t, tar):: aux q l' + | _ -> (t', tar'):: aux l q' + end + in aux l l' + + (** [merge_map f l] uses the combinator [f] to combine the head of the + list [l] with the merge_maped tail of [l] *) + let rec merge_map (f : 'a -> 'b list -> 'b list) (l : 'a list) : 'b list = + match l with + | [] -> [] + | t::q -> f t (merge_map f q) + + + (** This function has to deal with the arities *) + let rec merge (l : nf_term mset) (l' : nf_term mset) : nf_term mset= + merge_ac nf_term_compare l l' + + let extract_A units s t = + match t with + | TA (s',l) when s' = s -> l + | TUnit u when is_unit units s u -> [] + | _ -> [t] + + let extract_AC units s (t,ar) : nf_term mset = + match t with + | TAC (s',l) when s' = s -> List.map (fun (x,y) -> (x,y*ar)) l + | TUnit u when is_unit units s u -> [] + | _ -> [t,ar] + + + (** {3 Constructors of {!nf_term}}*) + let mk_TAC units (s : symbol) (l : (nf_term *int) list) = + mk_TAC' units s + (merge_map (fun u l -> merge (extract_AC units s u) l) l) + let mk_TA units s l = + mk_TA' units s + (merge_map (fun u l -> (extract_A units s u) @ l) l) + let mk_TSym s l = TSym (s,l) + let mk_TVar v = TVar v + let mk_TUnit s = TUnit s + + + (** {2 Printing function} *) + let print_binary_list (single : 'a -> string) + (unit : string) + (binary : string -> string -> string) (l : 'a list) = + let rec aux l = + match l with + [] -> unit + | [t] -> single t + | t::q -> + let r = aux q in + Printf.sprintf "%s" (binary (single t) r) + in + aux l + + let print_symbol s = + match s with + | s, None -> Printf.sprintf "%i" s + | s , Some u -> Printf.sprintf "%i(unit %i)" s u + + let sprint_ac (single : 'a -> string) (l : 'a mset) = + (print_binary_list + (fun (x,t) -> + if t = 1 + then single x + else Printf.sprintf "%i*%s" t (single x) + ) + "0" + (fun x y -> x ^ " , " ^ y) + l + ) + + let print_symbol single s l = + match l with + [] -> Printf.sprintf "%i" s + | _ -> + Printf.sprintf "%i(%s)" + s + (print_binary_list single "" (fun x y -> x ^ "," ^ y) l) + + + let print_ac single s l = + Printf.sprintf "[%s:AC]{%s}" + (string_of_int s ) + (sprint_ac + single + l + ) + + let print_a single s l = + Printf.sprintf "[%s:A]{%s}" + (string_of_int s) + (print_binary_list single "1" (fun x y -> x ^ " , " ^ y) l) + + let rec sprint_nf_term = function + | TSym (s,l) -> print_symbol sprint_nf_term s l + | TAC (s,l) -> print_ac sprint_nf_term s l + | TA (s,l) -> print_a sprint_nf_term s l + | TVar v -> Printf.sprintf "x%i" v + | TUnit s -> Printf.sprintf "unit%i" s + + + + + (** {2 Conversion functions} *) + + (* rebuilds a tree out of a list, under the assumption that the list is not empty *) + let rec binary_of_list f comb l = + let l = List.rev l in + let rec aux = function + | [] -> assert false + | [t] -> f t + | t::q -> comb (aux q) (f t) + in + aux l + + let term_of_t units : t -> nf_term = + let rec term_of_t = function + | Dot (s,l,r) -> + let l = term_of_t l in + let r = term_of_t r in + mk_TA units s [l;r] + | Plus (s,l,r) -> + let l = term_of_t l in + let r = term_of_t r in + mk_TAC units s [l,1;r,1] + | Unit x -> + mk_TUnit x + | Sym (s,t) -> + let t = Array.to_list t in + let t = List.map term_of_t t in + mk_TSym s t + | Var i -> + mk_TVar ( i) + in + term_of_t + + let rec t_of_term : nf_term -> t = function + | TAC (s,l) -> + (binary_of_list + t_of_term + (fun l r -> Plus ( s,l,r)) + (linear l) + ) + | TA (s,l) -> + (binary_of_list + t_of_term + (fun l r -> Dot ( s,l,r)) + l + ) + | TSym (s,l) -> + let v = Array.of_list l in + let v = Array.map (t_of_term) v in + Sym ( s,v) + | TVar x -> Var x + | TUnit s -> Unit s + + + let equal_aac units x y = + nf_equal (term_of_t units x) (term_of_t units y) + end + + (** Terms environments defined as association lists from variables to + terms in normal form {! Terms.nf_term} *) + module Subst : sig + type t + + val find : t -> var -> Terms.nf_term option + val add : t -> var -> Terms.nf_term -> t + val empty : t + val instantiate : t -> Terms.t -> Terms.t + val sprint : t -> string + val to_list : t -> (var*Terms.t) list + end + = + struct + open Terms + + (** Terms environments, with nf_terms, to avoid costly conversions + of {!Terms.nf_terms} to {!Terms.t}, that will be mostly discarded*) + type t = (var * nf_term) list + + let find : t -> var -> nf_term option = fun t x -> + try Some (List.assoc x t) with | _ -> None + let add t x v = (x,v) :: t + let empty = [] + + let sprint (l : t) = + match l with + | [] -> Printf.sprintf "Empty environment\n" + | _ -> + let s = List.fold_left + (fun acc (x,y) -> + Printf.sprintf "%sX%i -> %s\n" + acc + x + (sprint_nf_term y) + ) + "" + (List.rev l) in + Printf.sprintf "%s\n%!" s + + + + (** [instantiate] is an homomorphism except for the variables*) + let instantiate (t: t) (x:Terms.t) : Terms.t = + let rec aux = function + | Unit _ as x -> x + | Sym (s,t) -> Sym (s,Array.map aux t) + | Plus (s,l,r) -> Plus (s, aux l, aux r) + | Dot (s,l,r) -> Dot (s, aux l, aux r) + | Var i -> + begin match find t i with + | None -> Util.error "aac_tactics: instantiate failure" + | Some x -> t_of_term x + end + in aux x + + let to_list t = List.map (fun (x,y) -> x,Terms.t_of_term y) t + end + + (******************) + (* MATCHING UTILS *) + (******************) + + open Terms + + (** Since most of the folowing functions require an extra parameter, + the units, we package them in a module. This functor will then be + applied to a module containing the units, in the exported + functions. *) + module M (X : sig + val units : units + val is_ac : (symbol * bool) list + val strict : bool (* variables cannot be instantiated with units *) + end) = struct + + open X + + let print_units ()= + List.iter (fun (op,unit) -> Printf.printf "%i %i\t" op unit) units; + Printf.printf "\n%!" + + let mk_TAC s l = mk_TAC units s l + let mk_TA s l = mk_TA units s l + let mk_TAC' s l = + try return( mk_TAC s l) + with _ -> fail () + let mk_TA' s l = + try return( mk_TA s l) + with _ -> fail () + + (** First, we need to be able to perform non-deterministic choice of + term splitting to satisfy a pattern. Indeed, we want to show that: + (x+a*b)*c <= a*b*c + *) + let a_nondet_split_raw t : ('a list * 'a list) m = + let rec aux l l' = + match l' with + | [] -> + return ( l,[]) + | t::q -> + return ( l,l' ) + >>| aux (l @ [t]) q + in + aux [] t + + (** Same as the previous [a_nondet_split], but split the list in 3 + parts *) + let a_nondet_middle t : ('a list * 'a list * 'a list) m = + a_nondet_split_raw t >> + (fun (left, right) -> + a_nondet_split_raw left >> + (fun (left, middle) -> return (left, middle, right)) + ) + + (** Non deterministic splits of ac lists *) + let dispatch f n = + let rec aux k = + if k = 0 then return (f n 0) + else return (f (n-k) k) >>| aux (k-1) + in + aux (n ) + + let add_with_arith x ar l = + if ar = 0 then l else (x,ar) ::l + + let ac_nondet_split_raw (l : 'a mset) : ('a mset * 'a mset) m = + let rec aux = function + | [] -> return ([],[]) + | (t,tar)::q -> + aux q + >> + (fun (left,right) -> + dispatch (fun arl arr -> + add_with_arith t arl left, + add_with_arith t arr right + ) + tar + ) + in + aux l + + let a_nondet_split current t : (nf_term * nf_term list) m= + a_nondet_split_raw t + >> + (fun (l,r) -> + if strict && (l=[]) + then fail() + else + mk_TA' current l >> + fun t -> return (t, r) + ) + + let ac_nondet_split current t : (nf_term * nf_term mset) m= + ac_nondet_split_raw t + >> + (fun (l,r) -> + if strict && (l=[]) + then fail() + else + mk_TAC' current l >> + fun t -> return (t, r) + ) + + + (** Try to affect the variable [x] to each left factor of [t]*) + let var_a_nondet_split env current x t = + a_nondet_split current t + >> + (fun (t,res) -> + return ((Subst.add env x t), res) + ) + + (** Try to affect variable [x] to _each_ subset of t. *) + let var_ac_nondet_split (current: symbol) env (x : var) (t : nf_term mset) : (Subst.t * (nf_term mset)) m = + ac_nondet_split current t + >> + (fun (t,res) -> + return ((Subst.add env x t), res) + ) + + (** See the term t as a given AC symbol. Unwrap the first constructor + if necessary *) + let get_AC (s : symbol) (t : nf_term) : (nf_term *int) list = + match t with + | TAC (s',l) when s' = s -> l + | TUnit u when is_unit units s u -> [] + | _ -> [t,1] + + (** See the term t as a given A symbol. Unwrap the first constructor + if necessary *) + let get_A (s : symbol) (t : nf_term) : nf_term list = + match t with + | TA (s',l) when s' = s -> l + | TUnit u when is_unit units s u -> [] + | _ -> [t] + + (** See the term [t] as an symbol [s]. Fail if it is not such + symbol. *) + let get_Sym s t = + match t with + | TSym (s',l) when s' = s -> return l + | _ -> fail () + + (*************) + (* A Removal *) + (*************) + + (** We remove the left factor v in a term list. This function runs + linearly with respect to the size of the first pattern symbol *) + + let left_factor current (v : nf_term) (t : nf_term list) = + let rec aux a b = + match a,b with + | t::q , t' :: q' when nf_equal t t' -> aux q q' + | [], q -> return q + | _, _ -> fail () + in + match v with + | TA (s,l) when s = current -> aux l t + | TUnit u when is_unit units current u -> return t + | _ -> + begin match t with + | [] -> fail () + | t::q -> + if nf_equal v t + then return q + else fail () + end + + + (**************) + (* AC Removal *) + (**************) + + (** {!pick_sym} gather all elements of a list that satisfies a + predicate, and combine them with the residual of the list. That + is, each element of the residual contains exactly one element less + than the original term. + + TODO : This function not as efficient as it could be, using a + proper data-structure *) + + let pick_sym (s : symbol) (t : nf_term mset ) = + let rec aux front back = + match back with + | [] -> fail () + | (t,tar)::q -> + begin match t with + | TSym (s',v') when s = s' -> + let back = + if tar > 1 + then (t,tar -1) ::q + else q + in + return (v' , List.rev_append front back ) + >>| aux ((t,tar)::front) q + | _ -> aux ((t,tar)::front) q + end + in + aux [] t + + + + (** We have to check if we are trying to remove a unit from a term. Could also be located in Terms*) + let is_unit_AC s t = + try nf_equal t (mk_TAC s []) + with | NoUnit -> false + + let is_same_AC s t : nf_term mset option= + match t with + TAC (s',l) when s = s' -> Some l + | _ -> None + + (** We want to remove the term [v] from the term list [t] under an AC + symbol *) + let single_AC_factor (s : symbol) (v : nf_term) v_ar (t : nf_term mset) : (nf_term mset) m = + let rec aux front back = + match back with + | [] -> fail () + | (t,tar)::q -> + begin + if nf_equal v t + then + match () with + | _ when tar < v_ar -> fail () + | _ when tar = v_ar -> return (List.rev_append front q) + | _ -> return (List.rev_append front ((t,tar-v_ar)::q)) + else + aux ((t,tar) :: front) q + end + in + if is_unit_AC s v + then + return t + else + aux [] t + + (* Remove a constant from a mset. If this constant is also a mset for + the same symbol, we remove every elements, one at a time (and we do + care of the arities) *) + let factor_AC (s : symbol) (v: nf_term) (t : nf_term mset) : ( nf_term mset ) m = + match is_same_AC s v with + | None -> single_AC_factor s v 1 t + | Some l -> + (* We are trying to remove an AC factor *) + List.fold_left (fun acc (v,v_ar) -> + acc >> (single_AC_factor s v v_ar) + ) + (return t) + l + + +(** [tri_fold f l acc] folds on the list [l] and give to f the + beginning of the list in reverse order, the considered element, and + the last part of the list + + as an exemple, on the list [1;2;3;4], we get the trace + f () [] 1 [2; 3; 4] + f () [1] 2 [3; 4] + f () [2;1] 3 [ 4] + f () [3;2;1] 4 [] + + it is the duty of the user to reverse the front if needed +*) + +let tri_fold f (l : 'a list) (acc : 'b)= match l with + [] -> acc + | _ -> + let _,_,acc = List.fold_left (fun acc (t : 'a) -> + let l,r,acc = acc in + let r = List.tl r in + t::l,r,f acc l t r + ) ([], l,acc) l + in acc + +(* let test = tri_fold (fun acc l t r -> (l,t,r) :: acc) [1;2;3;4] [] *) + + + + (*****************************) + (* ENUMERATION DES POSITIONS *) + (*****************************) + + +(** The pattern is a unit: we need to try to make it appear at each + position. We do not need to go further with a real matching, since + the match should be trivial. Hence, we proceed recursively to + enumerate all the positions. *) + +module Positions = struct + + + let ac (l: 'a mset) : ('a mset * 'a )m = + let rec aux = function + | [] -> assert false + | [t,1] -> return ([],t) + | [t,tar] -> return ([t,tar -1],t) + | (t,tar) as h :: q -> + (aux q >> (fun (c,x) -> return (h :: c,x))) + >>| (if tar > 1 then return ((t,tar-1) :: q,t) else return (q,t)) + in + aux l + + let ac' current (l: 'a mset) : ('a mset * 'a )m = + ac_nondet_split_raw l >> + (fun (l,r) -> + if l = [] || r = [] + then fail () + else + mk_TAC' current r >> + fun t -> return (l, t) + ) + + let a (l : 'a list) : ('a list * 'a * 'a list) m = + let rec aux left right : ('a list * 'a * 'a list) m = + match right with + | [] -> assert false + | [t] -> return (left,t,[]) + | t::q -> + aux (left@[t]) q + >>| return (left,t,q) + in + aux [] l +end + +let build_ac (current : symbol) (context : nf_term mset) (p : t) : t m= + if context = [] + then return p + else + mk_TAC' current context >> + fun t -> return (Plus (current,t_of_term t,p)) + +let build_a (current : symbol) + (left : nf_term list) (right : nf_term list) (p : t) : t m= + let right_pat p = + if right = [] + then return p + else + mk_TA' current right >> + fun t -> return (Dot (current,p,t_of_term t)) + in + let left_pat p= + if left = [] + then return p + else + mk_TA' current left >> + fun t -> return (Dot (current,t_of_term t,p)) + in + right_pat p >> left_pat >> (fun p -> return p) + + +let conts (hole : t) (l : symbol list) p : t m = + let p = t_of_term p in + (* - aller chercher les symboles ac et les symboles a + - construire pour chaque + * * + + / \ / \ / \ + 1 p p 1 p 1 + *) + let ac,a = List.partition (fun s -> List.assoc s is_ac) l in + let acc = fail () in + let acc = List.fold_left + (fun acc s -> + acc >>| return (Plus (s,p,hole)) + ) acc ac in + let acc = + List.fold_left + (fun acc s -> + acc >>| return (Dot (s,p,hole)) >>| return (Dot (s,hole,p)) + ) acc a + in acc + + +(** + Return the same thing as subterm : + - The size of the context + - The context + - A collection of substitutions (here == return Subst.empty) +*) +let unit_subterm (t : nf_term) (unit : symbol) (hole : t): (int * t * Subst.t m) m = + let symbols = List.fold_left + (fun acc (ac,u) -> if u = unit then ac :: acc else acc ) [] units + in + (* make a unit appear at each strict sub-position of the term*) + let rec positions (t : nf_term) : t m = + match t with + (* with a final unit at the end *) + | TAC (s,l) -> + let symbols' = List.filter (fun x -> x <> s) symbols in + ( + ac_nondet_split_raw l >> + (fun (l,r) -> if l = [] || r = [] then fail () else + ( + match r with + | [p,1] -> + positions p >>| conts hole symbols' p + | _ -> + mk_TAC' s r >> conts hole symbols' + ) >> build_ac s l )) + | TA (s,l) -> + let symbols' = List.filter (fun x -> x <> s) symbols in + ( + (* first the other symbols, and then the more simple case of + this particular symbol *) + a_nondet_middle l >> + (fun (l,m,r) -> + (* meant to break the symmetry *) + if (l = [] && r = []) + then fail () + else + ( + match m with + | [p] -> + positions p >>| conts hole symbols' p + | _ -> + mk_TA' s m >> conts hole symbols' + ) >> build_a s l r )) + >>| + ( + if List.mem s symbols then + begin match l with + | [a] -> assert false + | [a;b] -> build_a s [a] [b] (hole) + | _ -> + (* on ne construit que les elements interieurs, + d'ou la disymetrie *) + (Positions.a l >> + (fun (left,p,right) -> + if left = [] then fail () else + (build_a s left right (Dot (s,hole,t_of_term p))))) + end + else fail () + ) + + | TSym (s,l) -> + tri_fold (fun acc l t r -> + ((positions t) >> + (fun (p) -> + let l = List.map t_of_term l in + let r = List.map t_of_term r in + return (Sym (s, Array.of_list (List.rev_append l (p::r)))) )) + >>| + ( + conts hole symbols t >> + (fun (p) -> + let l = List.map t_of_term l in + let r = List.map t_of_term r in + return (Sym (s, Array.of_list (List.rev_append l (p::r)))) ) + ) + >>| acc + ) l (fail()) + | TVar x -> assert false + | TUnit x when x = unit -> return (hole) + | TUnit x as t -> conts hole symbols t + in + (positions t + >>| + (match t with + | TSym _ -> conts hole symbols t + | TAC (s,l) -> conts hole symbols t + | TA (s,l) -> conts hole symbols t + | _ -> fail()) + ) + >> fun (p) -> return (Terms.size p,p,return Subst.empty) + + + + + (************) + (* Matching *) + (************) + + + +(** {!matching} is the generic matching judgement. Each time a + non-deterministic split is made, we have to come back to this one. + + {!matchingSym} is used to match two applications that have the + same (free) head-symbol. + + {!matchingAC} is used to match two sums (with the subtlety that + [x+y] matches [f a] which is a function application or [a*b] which + is a product). + + {!matchingA} is used to match two products (with the subtlety that + [x*y] matches [f a] which is a function application, or [a+b] + which is a sum). + + +*) + +let matching : Subst.t -> nf_term -> nf_term -> Subst.t Search.m= + let rec matching env (p : nf_term) (t: nf_term) : Subst.t Search.m= + match p with + | TAC (s,l) -> + let l = linear l in + matchingAC env s l (get_AC s t) + | TA (s,l) -> + matchingA env s l (get_A s t) + | TSym (s,l) -> + (get_Sym s t) + >> (fun t -> matchingSym env l t) + | TVar x -> + begin match Subst.find env x with + | None -> return (Subst.add env x t) + | Some v -> if nf_equal v t then return env else fail () + end + | TUnit s -> + if nf_equal p t then return env else fail () + and + matchingAC (env : Subst.t) (current : symbol) (l : nf_term list) (t : (nf_term *int) list) = + match l with + | TSym (s,v):: q -> + pick_sym s t + >> (fun (v',t') -> + matchingSym env v v' + >> (fun env -> matchingAC env current q t')) + + | TAC (s,v)::q when s = current -> + assert false + | TVar x:: q -> (* This is an optimization *) + begin match Subst.find env x with + | None -> + (var_ac_nondet_split current env x t) + >> (fun (env,residual) -> matchingAC env current q residual) + | Some v -> + (factor_AC current v t) + >> (fun residual -> matchingAC env current q residual) + end + | TUnit s as v :: q -> (* this is an optimization *) + (factor_AC current v t) >> + (fun residual -> matchingAC env current q residual) + | h :: q ->(* PAC =/= curent or PA or unit for this symbol*) + (ac_nondet_split current t) + >> + (fun (t,right) -> + matching env h t + >> + ( + fun env -> + matchingAC env current q right + ) + ) + | [] -> if t = [] then return env else fail () + and + matchingA (env : Subst.t) (current : symbol) (l : nf_term list) (t : nf_term list) = + match l with + | TSym (s,v) :: l -> + begin match t with + | TSym (s',v') :: r when s = s' -> + (matchingSym env v v') + >> (fun env -> matchingA env current l r) + | _ -> fail () + end + | TA (s,v) :: l when s = current -> + assert false + | TVar x :: l -> + begin match Subst.find env x with + | None -> + debug (Printf.sprintf "var %i (%s)" x + (let b = Buffer.create 21 in List.iter (fun t -> Buffer.add_string b ( Terms.sprint_nf_term t)) t; Buffer.contents b )); + var_a_nondet_split env current x t + >> (fun (env,residual)-> debug (Printf.sprintf "pl %i %i" x(List.length residual)); matchingA env current l residual) + | Some v -> + (left_factor current v t) + >> (fun residual -> matchingA env current l residual) + end + | TUnit s as v :: q -> (* this is an optimization *) + (left_factor current v t) >> + (fun residual -> matchingA env current q residual) + | h :: l -> + a_nondet_split current t + >> (fun (t,r) -> + matching env h t + >> (fun env -> matchingA env current l r) + ) + | [] -> if t = [] then return env else fail () + and + matchingSym (env : Subst.t) (l : nf_term list) (t : nf_term list) = + List.fold_left2 + (fun acc p t -> acc >> (fun env -> matching env p t)) + (return env) + l + t + + in + fun env l r -> + let _ = debug (Printf.sprintf "pattern :%s\nterm :%s\n%!" (Terms.sprint_nf_term l) (Terms.sprint_nf_term r)) in + let m = matching env l r in + let _ = debug (Printf.sprintf "count %i" (Search.count m)) in + m + + +(** [unitifiable p : Subst.t m] *) +let unitifiable p : (symbol * Subst.t m) m = + let m = List.fold_left + (fun acc (_,unit) -> acc >>| + let m = matching Subst.empty p (mk_TUnit unit) in + if Search.is_empty m then + fail () + else + begin + return (unit,m) + end + ) (fail ()) units + in + m +;; + +let nullifiable p = + let nullable = not strict in + let has_unit s = + try let _ = get_unit units s in true + with NoUnit -> false + in + let rec aux = function + | TA (s,l) -> has_unit s && List.for_all (aux) l + | TAC (s,l) -> has_unit s && List.for_all (fun (x,n) -> aux x) l + | TSym _ -> false + | TVar _ -> nullable + | TUnit _ -> true + in aux p + +let unit_warning p ~nullif ~unitif = + assert ((Search.is_empty unitif) || nullif); + if not (Search.is_empty unitif) + then + begin + Pp.msg_warning + (Pp.str + "[aac_tactics] This pattern can be instanciated to match units, some solutions can be missing"); + end + +;; + + + + +(***********) +(* Subterm *) +(***********) + + + +(** [subterm] solves a sub-term pattern matching. + + This function is more high-level than {!matcher}, thus takes {!t} + as arguments rather than terms in normal form {!nf_term}. + + We use three mutually recursive functions {!subterm}, + {!subterm_AC}, {!subterm_A} to find the matching subterm, making + non-deterministic choices to split the term into a context and an + intersting sub-term. Intuitively, the only case in which we have to + go in depth is when we are left with a sub-term that is atomic. + + Indeed, rewriting [H: b = c |- a+b+a = a+a+c], we do not want to + find recursively the sub-terms of [a+b] and [b+a], since they will + overlap with the sub-terms of [a+b+a]. + + We rebuild the context on the fly, leaving the variables in the + pattern uninstantiated. We do so in order to allow interaction + with the user, to choose the env. + + Strange patterms like x*y*x can be instanciated by nothing, inside + a product. Therefore, we need to check that all the term is not + going into the context. With proper support for interaction with + the user, we should lift these tests. However, at the moment, they + serve as heuristics to return "interesting" matchings +*) + + let return_non_empty raw_p m = + if is_empty m + then + fail () + else + return (raw_p ,m) + + let subterm (raw_p:t) (raw_t:t): (int * t * Subst.t m) m= + let _ = debug (String.make 40 '=') in + let p = term_of_t units raw_p in + let t = term_of_t units raw_t in + let nullif = nullifiable p in + let unitif = unitifiable p in + let _ = unit_warning p ~nullif ~unitif in + let _ = debug (Printf.sprintf "%s" (Terms.sprint_nf_term p)) in + let _ = debug (Printf.sprintf "%s" (Terms.sprint_nf_term t)) in + let filter_right current right (p,m) = + if right = [] + then return (p,m) + else + mk_TAC' current right >> + fun t -> return (Plus (current,p,t_of_term t),m) + in + let filter_middle current left right (p,m) = + let right_pat p = + if right = [] + then return p + else + mk_TA' current right >> + fun t -> return (Dot (current,p,t_of_term t)) + in + let left_pat p= + if left = [] + then return p + else + mk_TA' current left >> + fun t -> return (Dot (current,t_of_term t,p)) + in + right_pat p >> left_pat >> (fun p -> return (p,m)) + in + let rec subterm (t:nf_term) : (t * Subst.t m) m= + match t with + | TAC (s,l) -> + ((ac_nondet_split_raw l) >> + (fun (left,right) -> + (subterm_AC s left) >> (filter_right s right) + )) + | TA (s,l) -> + (a_nondet_middle l) >> + (fun (left, middle, right) -> + (subterm_A s middle) >> + (filter_middle s left right) + ) + | TSym (s, l) -> + let init = + return_non_empty raw_p (matching Subst.empty p t) + in + tri_fold (fun acc l t r -> + ((subterm t) >> + (fun (p,m) -> + let l = List.map t_of_term l in + let r = List.map t_of_term r in + let p = Sym (s, Array.of_list (List.rev_append l (p::r))) in + return (p,m) + )) >>| acc + ) l init + | TVar x -> assert false + (* this case is superseded by the later disjunction *) + | TUnit s -> fail () + + and subterm_AC s tl = + match tl with + [x,1] -> subterm x + | _ -> + mk_TAC' s tl >> fun t -> + return_non_empty raw_p (matching Subst.empty p t) + and subterm_A s tl = + match tl with + [x] -> subterm x + | _ -> + mk_TA' s tl >> + fun t -> + return_non_empty raw_p (matching Subst.empty p t) + in + match p with + | TUnit unit -> unit_subterm t unit raw_p + | _ when not (Search.is_empty unitif) -> + let unit_matches = + Search.fold + (fun (unit,inst) acc -> + Search.fold + (fun subst acc' -> + let m = unit_subterm t unit (Subst.instantiate subst raw_p) + in + m>>| acc' + ) + inst + acc + ) + unitif + (fail ()) + in + let nullifies (t : Subst.t) = + List.for_all (fun (_,x) -> + List.exists (fun (_,y) -> Unit y = x ) units + ) (Subst.to_list t) + in + let nonunit_matches = + subterm t >> + ( + fun (p,m) -> + let m = Search.filter (fun subst -> not( nullifies subst)) m in + if Search.is_empty m + then fail () + else return (Terms.size p,p,m) + ) + in + unit_matches >>| nonunit_matches + + | _ -> (subterm t >> fun (p,m) -> return (Terms.size p,p,m)) + + + end + + +(* The functions we export, handlers for the previous ones. Some debug + information also *) +let subterm ?(strict = false) units raw t = + let module M = M (struct + let is_ac = units.is_ac + let units = units.unit_for + let strict = strict + end) in + let sols = time (M.subterm raw) t "%fs spent in subterm (including matching)\n" in + debug + (Printf.sprintf "%i possible solution(s)\n" + (Search.fold (fun (_,_,envm) acc -> count envm + acc) sols 0)); + sols + + +let matcher ?(strict = false) units p t = + let module M = M (struct + let is_ac = units.is_ac + let units = units.unit_for + let strict = false + end) in + let units = units.unit_for in + let sols = time + (fun (p,t) -> + let p = (Terms.term_of_t units p) in + let t = (Terms.term_of_t units t) in + M.matching Subst.empty p t) (p,t) + "%fs spent in the matcher\n" + in + debug (Printf.sprintf "%i solutions\n" (count sols)); + sols + -- cgit v1.2.3