diff options
author | 2014-07-21 10:03:04 +0200 | |
---|---|---|
committer | 2014-08-05 18:38:28 +0200 | |
commit | 7dba9d3f3ce62246b9d8562d2818c63ba37b206e (patch) | |
tree | fbf0e133e160a5f7ff03f8a0b5bb4d0f47b43105 /stm/stm.ml | |
parent | 4e724634839726aa11534f16e4bfb95cd81232a4 (diff) |
STM: new "par:" goal selector, like "all:" but in parallel
par: distributes the goals among a number of workers given
by -async-proofs-tac-j (defaults to 2).
Diffstat (limited to 'stm/stm.ml')
-rw-r--r-- | stm/stm.ml | 215 |
1 files changed, 184 insertions, 31 deletions
diff --git a/stm/stm.ml b/stm/stm.ml index 0827e0bfa..20448947f 100644 --- a/stm/stm.ml +++ b/stm/stm.ml @@ -13,6 +13,7 @@ let prerr_endline s = if !Flags.debug then begin pr_err s end else () let (f_process_error, process_error_hook) = Hook.make () let ((f_interp : (?verbosely:bool -> ?proof:Proof_global.closed_proof -> Loc.t * Vernacexpr.vernac_expr -> unit) Hook.value), interp_hook) = Hook.make () +let with_fail, with_fail_hook = Hook.make () open Vernacexpr open Errors @@ -82,7 +83,7 @@ type branch_type = [ `Master | `Proof of proof_mode * depth | `Edit of proof_mode * Stateid.t * Stateid.t ] -type cmd_t = ast * Id.t list +type cmd_t = ast * Id.t list * bool type fork_t = ast * Vcs_.Branch.t * Vernacexpr.opacity_guarantee * Id.t list type qed_t = { qast : ast; @@ -249,7 +250,7 @@ end = struct let fname = "stm_" ^ Str.global_replace (Str.regexp " ") "_" (System.process_id ()) in let string_of_transaction = function - | Cmd (t, _) | Fork (t, _,_,_) -> + | Cmd (t, _, _) | Fork (t, _,_,_) -> (try string_of_ppcmds (pr_ast t) with _ -> "ERR") | Sideff (Some t) -> sprintf "Sideff(%s)" @@ -538,7 +539,7 @@ module State : sig ?redefine:bool -> ?cache:Summary.marshallable -> (unit -> unit) -> Stateid.t -> unit val install_cached : Stateid.t -> unit - val is_cached : Stateid.t -> bool + val is_cached : ?cache:Summary.marshallable -> Stateid.t -> bool val exn_on : Stateid.t -> ?valid:Stateid.t -> exn -> exn @@ -566,13 +567,23 @@ end = struct let () = Future.set_freeze (fun () -> in_t (freeze_global_state `No, !cur_id)) (fun t -> let s,i = out_t t in unfreeze_global_state s; cur_id := i) + + type frozen_state = state - let is_cached id = - Stateid.equal id !cur_id || - try match VCS.get_info id with - | { state = Some _ } -> true - | _ -> false - with VCS.Expired -> false + let freeze marhallable id = VCS.set_state id (freeze_global_state marhallable) + + let is_cached ?(cache=`No) id = + if Stateid.equal id !cur_id then + try match VCS.get_info id with + | { state = None } when cache = `Yes -> freeze `No id; true + | { state = None } when cache = `Shallow -> freeze `Shallow id; true + | _ -> true + with VCS.Expired -> false + else + try match VCS.get_info id with + | { state = Some _ } -> true + | _ -> false + with VCS.Expired -> false let install_cached id = if Stateid.equal id !cur_id then () else (* optimization *) @@ -582,8 +593,6 @@ end = struct | _ -> anomaly (str "unfreezing a non existing state") in unfreeze_global_state s; cur_id := id - type frozen_state = state - let get_cached id = try match VCS.get_info id with | { state = Some s } -> s @@ -594,8 +603,6 @@ end = struct try if VCS.get_state id = None then VCS.set_state id s with VCS.Expired -> () - let freeze marhallable id = VCS.set_state id (freeze_global_state marhallable) - let exn_on id ?valid e = match Stateid.get e with | Some _ -> e @@ -700,7 +707,8 @@ module Task = struct let name_of_task t = t.t_name let name_of_request r = r.r_name - let request_of_task { t_exn_info; t_start; t_stop; t_loc; t_uuid; t_name } = + let request_of_task age { t_exn_info;t_start;t_stop;t_loc;t_uuid;t_name } = + assert(age = `Fresh); try Some { r_exn_info = t_exn_info; r_stop = t_stop; @@ -764,7 +772,7 @@ module Task = struct VCS.print (); RespBuiltProof(rc,time) with - |e when Errors.noncritical e -> + | e when Errors.noncritical e -> (* This can happen if the proof is broken. The error has also been * signalled as a feedback, hence we can silently recover *) let e_error_at, e_safe_id = match Stateid.get e with @@ -877,7 +885,7 @@ end = struct spc () ++ print e) | Some (_, cur) -> match VCS.visit cur with - | { step = `Cmd ( { loc }, _) } + | { step = `Cmd ( { loc }, _, _) } | { step = `Fork ( { loc }, _, _, _) } | { step = `Qed ( { qast = { loc } }, _) } | { step = `Sideff (`Ast ( { loc }, _)) } -> @@ -938,9 +946,9 @@ end = struct let set_perspective idl = let open Stateid in let p = List.fold_right Set.add idl Set.empty in - TQueue.set_order queue (fun task1 task2 -> - let TaskBuildProof (_, a1, b1, _, _,_,_,_) = task1 in - let TaskBuildProof (_, a2, b2, _, _,_,_,_) = task2 in + TaskQueue.set_order (fun task1 task2 -> + let { Task.t_start = a1; Task.t_stop = b1 } = task1 in + let { Task.t_start = a2; Task.t_stop = b2 } = task2 in match Set.mem a1 p || Set.mem b1 p, Set.mem a2 p || Set.mem b2 p with | true, true | false, false -> 0 | true, false -> -1 @@ -983,10 +991,150 @@ end = struct let tasks = TaskQueue.dump () in prerr_endline (Printf.sprintf "dumping %d tasks\n" (List.length tasks)); List.map (function r -> { r with r_uuid = List.assoc r.r_uuid f2t_map }) - (CList.map_filter Task.request_of_task tasks) + (CList.map_filter (Task.request_of_task `Fresh) tasks) + +end + +module SubTask = struct + + let reach_known_state = ref (fun ?redefine_qed ~cache id -> ()) + let set_reach_known_state f = reach_known_state := f + + type output = Constr.constr * Evd.evar_universe_context + + let forward_feedback = forward_feedback + type task = { + t_state : Stateid.t; + t_state_fb : Stateid.t; + t_assign : output Future.assignement -> unit; + t_ast : ast; + t_goal : Goal.goal; + t_kill : unit -> unit; + t_name : string } + + type request = { + r_state : Stateid.t; + r_state_fb : Stateid.t; + r_document : VCS.vcs option; + r_ast : ast; + r_goal : Goal.goal; + r_name : string } + + type response = + | RespBuiltSubProof of output + | RespError of std_ppcmds + + let name = "tacworker" + let extra_env () = [||] + + (* run by the master, on a thread *) + let request_of_task age { t_state; t_state_fb; t_ast; t_goal; t_name } = + try Some { + r_state = t_state; + r_state_fb = t_state_fb; + r_document = + if age = `Old then None + else Some (VCS.slice ~start:t_state ~stop:t_state); + r_ast = t_ast; + r_goal = t_goal; + r_name = t_name } + with VCS.Expired -> None + + let use_response { t_assign; t_state; t_state_fb; t_kill } = function + | RespBuiltSubProof o -> t_assign (`Val o); `Stay + | RespError msg -> + let e = Stateid.add ~valid:t_state (RemoteException msg) t_state_fb in + t_assign (`Exn e); + t_kill (); + `Stay + + let on_marshal_error err { t_name } = + pr_err ("Fatal marshal error: " ^ t_name ); + flush_all (); exit 1 + + let on_slave_death task = `Stay + let on_task_cancellation_or_expiration task = () (* We shall die *) + + let perform { r_state = id; r_state_fb; r_document = vcs; r_ast; r_goal } = + Option.iter VCS.restore vcs; + try + !reach_known_state ~cache:`No id; + let t, uc = Future.purify (fun () -> + vernac_interp r_state_fb r_ast; + let _,_,_,_,sigma = Proof.proof (Proof_global.give_me_the_proof ()) in + match Goal.solution sigma r_goal with + | None -> Errors.errorlabstrm "Stm" (str "no progress") + | Some t -> + let t = Evarutil.nf_evar sigma t in + if Evarutil.is_ground_term sigma t then + t, Evd.evar_universe_context sigma + else Errors.errorlabstrm "Stm" (str"The solution is not ground")) + () in + RespBuiltSubProof (t,uc) + with e when Errors.noncritical e -> RespError (Errors.print e) + + let name_of_task { t_name } = t_name + let name_of_request { r_name } = r_name + end +module Partac = struct + + module TaskQueue = AsyncTaskQueue.Make(SubTask) + + let vernac_interp nworkers safe_id id { verbose; loc; expr = e } = + let e, etac, time, fail = + let rec find time fail = function VernacSolve(_,re,b) -> re, b, time, fail + | VernacTime [_,e] -> find true fail e + | VernacFail e -> find time true e + | _ -> errorlabstrm "Stm" (str"unsupported") in find false false e in + Hook.get with_fail fail (fun () -> + (if time then System.with_time false else (fun x -> x)) (fun () -> + ignore(TaskQueue.with_n_workers nworkers (fun ~join ~cancel_all -> + Proof_global.with_current_proof (fun _ p -> + let goals, _, _, _, _ = Proof.proof p in + let open SubTask in + let res = CList.map_i (fun i g -> + let f,assign= Future.create_delegate (State.exn_on id ~valid:safe_id) in + let t_ast = { verbose;loc;expr = VernacSolve(SelectNth i,e,etac) } in + let t_name = Goal.uid g in + TaskQueue.enqueue_task + { t_state = safe_id; t_state_fb = id; + t_assign = assign; t_ast; t_goal = g; t_name; t_kill = cancel_all } + (ref false); + Goal.uid g,f) + 1 goals in + join (); + let assign_tac : unit Proofview.tactic = + Proofview.V82.tactic (fun gl -> + let open Tacmach in + let sigma, g = project gl, sig_it gl in + let gid = Goal.uid g in + let f = + try List.assoc gid res + with Not_found -> Errors.anomaly(str"Partac: wrong focus") in + if Future.is_over f then + let pt, uc = Future.join f in + prerr_endline Pp.(string_of_ppcmds(hov 0 ( + str"g=" ++ str gid ++ spc () ++ + str"t=" ++ (Printer.pr_constr pt) ++ spc () ++ + str"uc=" ++ Evd.pr_evar_universe_context uc))); + let sigma = Goal.V82.partial_solution sigma g pt in + let sigma = Evd.merge_universe_context sigma uc in + re_sig [] sigma + else (* One has failed and cancelled the others, but not this one *) + re_sig [g] sigma) in + Proof.run_tactic (Global.env()) assign_tac p)))) ()) + + let slave_main_loop = TaskQueue.slave_main_loop + let slave_init_stdout = TaskQueue.slave_init_stdout + +end + +let tacslave_main_loop () = Partac.slave_main_loop Ephemeron.clear +let tacslave_init_stdout = Partac.slave_init_stdout + (* Runs all transactions needed to reach a state *) module Reach : sig @@ -1019,7 +1167,7 @@ let collect_proof cur hd brkind id = let rec collect last accn id = let view = VCS.visit id in match last, view.step with - | _, `Cmd (x, _) -> collect (Some (id,x)) (id::accn) view.next + | _, `Cmd (x, _, _) -> collect (Some (id,x)) (id::accn) view.next | _, `Alias _ -> `Sync (no_name,`Alias) | _, `Fork(_,_,_,_::_::_)-> `Sync (no_name,`MutualProofs) | _, `Fork(_,_,Doesn'tGuaranteeOpacity,_) -> @@ -1099,7 +1247,7 @@ let known_state ?(redefine_qed=false) ~cache id = (* traverses the dag backward from nodes being already calculated *) and reach ?(redefine_qed=false) ?(cache=cache) id = prerr_endline ("reaching: " ^ Stateid.to_string id); - if not redefine_qed && State.is_cached id then begin + if not redefine_qed && State.is_cached ~cache id then begin State.install_cached id; feedback ~state_id:id Feedback.Processed; prerr_endline ("reached (cache)") @@ -1110,9 +1258,13 @@ let known_state ?(redefine_qed=false) ~cache id = | `Alias id -> (fun () -> reach view.next; reach id ), cache - | `Cmd (x,_) -> (fun () -> + | `Cmd (x,_,false) -> (fun () -> reach view.next; vernac_interp id x ), cache + | `Cmd (x,_,true) -> (fun () -> + reach ~cache:`Shallow view.next; + Partac.vernac_interp !Flags.async_proofs_n_tacworkers view.next id x + ), cache | `Fork (x,_,_,_) -> (fun () -> reach view.next; vernac_interp id x; wall_clock_last_fork := Unix.gettimeofday () @@ -1205,6 +1357,7 @@ let known_state ?(redefine_qed=false) ~cache id = end let _ = Task.set_reach_known_state Reach.known_state +let _ = SubTask.set_reach_known_state Reach.known_state (* The backtrack module simulates the classic behavior of a linear document *) module Backtrack : sig @@ -1263,7 +1416,7 @@ end = struct if id = Stateid.initial || id = Stateid.dummy then [] else match VCS.visit id with | { step = `Fork (_,_,_,l) } -> l - | { step = `Cmd (_,l) } -> l + | { step = `Cmd (_,l,_) } -> l | _ -> [] in match f acc (id, vcs, ids) with | `Stop x -> x @@ -1550,7 +1703,7 @@ let process_transaction ?(newtip=Stateid.fresh ()) ~tty verbose c (loc, expr) = | VtQuery (true,report_id), w -> assert(Stateid.equal report_id Stateid.dummy); let id = VCS.new_node ~id:newtip () in - VCS.commit id (Cmd (x,[])); + VCS.commit id (Cmd (x,[],false)); Backtrack.record (); if w == VtNow then finish (); `Ok | VtQuery (false,_), VtLater -> anomaly(str"classifier: VtQuery + VtLater must imply part_of_script") @@ -1569,7 +1722,7 @@ let process_transaction ?(newtip=Stateid.fresh ()) ~tty verbose c (loc, expr) = | VtProofMode mode, VtNow -> let id = VCS.new_node ~id:newtip () in VCS.checkout VCS.Branch.master; - VCS.commit id (Cmd (x,[])); + VCS.commit id (Cmd (x,[],false)); VCS.propagate_sideff (Some x); List.iter (fun bn -> match VCS.get_branch bn with @@ -1585,9 +1738,9 @@ let process_transaction ?(newtip=Stateid.fresh ()) ~tty verbose c (loc, expr) = Backtrack.record (); finish (); `Ok - | VtProofStep, w -> + | VtProofStep paral, w -> let id = VCS.new_node ~id:newtip () in - VCS.commit id (Cmd (x,[])); + VCS.commit id (Cmd (x,[],paral)); Backtrack.record (); if w == VtNow then finish (); `Ok | VtQed keep, w -> let rc = merge_proof_branch ~id:newtip x keep head in @@ -1602,7 +1755,7 @@ let process_transaction ?(newtip=Stateid.fresh ()) ~tty verbose c (loc, expr) = | VtSideff l, w -> let id = VCS.new_node ~id:newtip () in VCS.checkout VCS.Branch.master; - VCS.commit id (Cmd (x,l)); + VCS.commit id (Cmd (x,l,false)); VCS.propagate_sideff (Some x); VCS.checkout_shallowest_proof_branch (); Backtrack.record (); if w == VtNow then finish (); `Ok @@ -1624,7 +1777,7 @@ let process_transaction ?(newtip=Stateid.fresh ()) ~tty verbose c (loc, expr) = VCS.branch bname (`Proof ("Classic", VCS.proof_nesting () + 1)); Proof_global.activate_proof_mode "Classic"; end else begin - VCS.commit id (Cmd (x,[])); + VCS.commit id (Cmd (x,[],false)); VCS.propagate_sideff (Some x); VCS.checkout_shallowest_proof_branch (); end in @@ -1848,7 +2001,7 @@ let get_script prf = | `Sideff (`Ast (x,_)) -> find ((x.expr, (VCS.get_info id).n_goals)::acc) view.next | `Sideff (`Id id) -> find acc id - | `Cmd (x,_) -> find ((x.expr, (VCS.get_info id).n_goals)::acc) view.next + | `Cmd (x,_,_) -> find ((x.expr, (VCS.get_info id).n_goals)::acc) view.next | `Alias id -> find acc id | `Fork _ -> find acc view.next in |