diff options
author | 2018-04-06 17:39:17 -0700 | |
---|---|---|
committer | 2018-04-06 18:18:06 -0700 | |
commit | 470cc0f75108e68965f89026399f7b3a7a08196b (patch) | |
tree | e4997f253781b176976f9baf9949d0a0e3751c8a | |
parent | 38d1ac1e4f5b2a6e88eee43d332292898e0afc41 (diff) |
Add remote session support for the MakeCallable API.
PiperOrigin-RevId: 191964391
18 files changed, 898 insertions, 156 deletions
diff --git a/tensorflow/core/distributed_runtime/local_master.cc b/tensorflow/core/distributed_runtime/local_master.cc index aaa4cfa734..76315462a7 100644 --- a/tensorflow/core/distributed_runtime/local_master.cc +++ b/tensorflow/core/distributed_runtime/local_master.cc @@ -157,6 +157,47 @@ Status LocalMaster::Reset(CallOptions* call_options, return ret; } +Status LocalMaster::MakeCallable(CallOptions* call_options, + const MakeCallableRequest* request, + MakeCallableResponse* response) { + Notification n; + Status ret; + master_impl_->MakeCallable(request, response, [&n, &ret](const Status& s) { + ret.Update(s); + n.Notify(); + }); + TF_RETURN_IF_ERROR( + WaitForNotification(call_options, default_timeout_in_ms_, &n)); + return ret; +} +Status LocalMaster::RunCallable(CallOptions* call_options, + const RunCallableRequest* request, + RunCallableResponse* response) { + Notification n; + Status ret; + master_impl_->RunCallable(call_options, request, response, + [&n, &ret](const Status& s) { + ret.Update(s); + n.Notify(); + }); + TF_RETURN_IF_ERROR( + WaitForNotification(call_options, default_timeout_in_ms_, &n)); + return ret; +} +Status LocalMaster::ReleaseCallable(CallOptions* call_options, + const ReleaseCallableRequest* request, + ReleaseCallableResponse* response) { + Notification n; + Status ret; + master_impl_->ReleaseCallable(request, response, [&n, &ret](const Status& s) { + ret.Update(s); + n.Notify(); + }); + TF_RETURN_IF_ERROR( + WaitForNotification(call_options, default_timeout_in_ms_, &n)); + return ret; +} + namespace { mutex* get_local_master_registry_lock() { static mutex local_master_registry_lock(LINKER_INITIALIZED); diff --git a/tensorflow/core/distributed_runtime/local_master.h b/tensorflow/core/distributed_runtime/local_master.h index c20b40329a..cad6babad8 100644 --- a/tensorflow/core/distributed_runtime/local_master.h +++ b/tensorflow/core/distributed_runtime/local_master.h @@ -71,6 +71,16 @@ class LocalMaster : public MasterInterface { Status Reset(CallOptions* call_options, const ResetRequest* request, ResetResponse* response) override; + Status MakeCallable(CallOptions* call_options, + const MakeCallableRequest* request, + MakeCallableResponse* response) override; + Status RunCallable(CallOptions* call_options, + const RunCallableRequest* request, + RunCallableResponse* response) override; + Status ReleaseCallable(CallOptions* call_options, + const ReleaseCallableRequest* request, + ReleaseCallableResponse* response); + // Registers the mapping from the given `target` to the given `master`. // // WARNING: The `master` pointer remains owned by the caller. It is diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc index 1a488303ac..f47502e844 100644 --- a/tensorflow/core/distributed_runtime/master.cc +++ b/tensorflow/core/distributed_runtime/master.cc @@ -611,4 +611,55 @@ void Master::Reset(const ResetRequest* req, ResetResponse* resp, }); } +void Master::MakeCallable(const MakeCallableRequest* req, + MakeCallableResponse* resp, MyClosure done) { + auto session = FindMasterSession(req->session_handle()); + if (session == nullptr) { + done(errors::Aborted("Session ", req->session_handle(), " is not found.")); + return; + } + + SchedClosure(std::bind( + [this, session, req, resp](MyClosure done) { + Status s = session->MakeCallable(*req, resp); + session->Unref(); + done(s); + }, + std::move(done))); +} + +void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req, + RunCallableResponse* resp, MyClosure done) { + auto session = FindMasterSession(req->session_handle()); + if (session == nullptr) { + done(errors::Aborted("Session ", req->session_handle(), " is not found.")); + return; + } + + SchedClosure(std::bind( + [this, session, opts, req, resp](MyClosure done) { + Status s = session->RunCallable(opts, *req, resp); + session->Unref(); + done(s); + }, + std::move(done))); +} + +void Master::ReleaseCallable(const ReleaseCallableRequest* req, + ReleaseCallableResponse* resp, MyClosure done) { + auto session = FindMasterSession(req->session_handle()); + if (session == nullptr) { + done(errors::Aborted("Session ", req->session_handle(), " is not found.")); + return; + } + + SchedClosure(std::bind( + [this, session, req, resp](MyClosure done) { + Status s = session->ReleaseCallable(*req, resp); + session->Unref(); + done(s); + }, + std::move(done))); +} + } // end namespace tensorflow diff --git a/tensorflow/core/distributed_runtime/master.h b/tensorflow/core/distributed_runtime/master.h index 678fc46bd7..dbb337fd48 100644 --- a/tensorflow/core/distributed_runtime/master.h +++ b/tensorflow/core/distributed_runtime/master.h @@ -61,6 +61,13 @@ class Master { // See tensorflow::Reset() and the comment on ResetRequest. void Reset(const ResetRequest* req, ResetResponse* resp, MyClosure done); + void MakeCallable(const MakeCallableRequest* req, MakeCallableResponse* resp, + MyClosure done); + void RunCallable(CallOptions* opts, const RunCallableRequest* req, + RunCallableResponse* resp, MyClosure done); + void ReleaseCallable(const ReleaseCallableRequest* req, + ReleaseCallableResponse* resp, MyClosure done); + private: typedef Master ME; diff --git a/tensorflow/core/distributed_runtime/master_interface.h b/tensorflow/core/distributed_runtime/master_interface.h index bf6a2db3e2..a8ae3cba3c 100644 --- a/tensorflow/core/distributed_runtime/master_interface.h +++ b/tensorflow/core/distributed_runtime/master_interface.h @@ -89,6 +89,16 @@ class MasterInterface { virtual Status Reset(CallOptions* call_options, const ResetRequest* request, ResetResponse* response) = 0; + virtual Status MakeCallable(CallOptions* call_options, + const MakeCallableRequest* request, + MakeCallableResponse* response) = 0; + virtual Status RunCallable(CallOptions* call_options, + const RunCallableRequest* request, + RunCallableResponse* response) = 0; + virtual Status ReleaseCallable(CallOptions* call_options, + const ReleaseCallableRequest* request, + ReleaseCallableResponse* response) = 0; + protected: // NOTE: This should only be called by implementations of this // interface whose CreateRunStepResponse() method returns a diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 64adf35c5e..e0a5bb4c53 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -72,7 +72,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { client_graph_(std::move(cg)), session_opts_(session_opts), is_partial_(is_partial), - debug_opts_(bopts.callable_options.run_options().debug_options()), + callable_opts_(bopts.callable_options), worker_cache_(worker_cache), should_deregister_(should_deregister) { VLOG(1) << "Created ReffedClientGraph for node with " @@ -94,12 +94,18 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { const ClientGraph* client_graph() { return client_graph_.get(); } + const CallableOptions& callable_options() { return callable_opts_; } + std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step, int64 execution_count, const RunOptions& ropts) { return stats_publisher_->GetProfileHandler(step, execution_count, ropts); } + int64 get_and_increment_execution_count() { + return execution_count_.fetch_add(1); + } + // Turn RPC logging on or off, both at the WorkerCache used by this // master process, and at each remote worker in use for the current // partitions. @@ -178,6 +184,10 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp, CancellationManager* cm, const bool is_last_partial_run); + Status RunPartitions(const MasterEnv* env, int64 step_id, + int64 execution_count, PerStepState* pss, + CallOptions* call_opts, const RunCallableRequest& req, + RunCallableResponse* resp, CancellationManager* cm); // Calls workers to cleanup states for the step "step_id". Calls // `done` when all cleanup RPCs have completed. @@ -211,10 +221,11 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { const std::unique_ptr<ClientGraph> client_graph_; const SessionOptions session_opts_; const bool is_partial_; - const DebugOptions& debug_opts_; + const CallableOptions callable_opts_; WorkerCacheInterface* const worker_cache_; // Not owned. std::unordered_map<StringPiece, Node*, StringPieceHasher> name_to_node_; const bool should_deregister_; + std::atomic<int64> execution_count_ = {0}; // Graph partitioned into per-location subgraphs. struct Part { @@ -269,6 +280,17 @@ class MasterSession::ReffedClientGraph : public core::RefCounted { const PartitionOptions& popts, std::unordered_map<string, GraphDef> graph_partitions); + // Prepares a number of calls to workers. One call per partition. + // This is a generic method that handles Run, PartialRun, and RunCallable. + template <class FetchListType, class ClientRequestType, + class ClientResponseType> + Status RunPartitionsHelper( + const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds, + const FetchListType& fetches, const MasterEnv* env, int64 step_id, + int64 execution_count, PerStepState* pss, CallOptions* call_opts, + const ClientRequestType& req, ClientResponseType* resp, + CancellationManager* cm, bool is_last_partial_run); + // Deregisters the partitions on the workers. Called in the // destructor and does not wait for the rpc completion. void DeregisterPartitions(); @@ -411,7 +433,8 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions( c->req.set_session_handle(session_handle_); c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]); *c->req.mutable_graph_options() = session_opts_.config.graph_options(); - *c->req.mutable_debug_options() = debug_opts_; + *c->req.mutable_debug_options() = + callable_opts_.run_options().debug_options(); VLOG(2) << "Register " << c->req.graph_def().DebugString(); auto cb = [c, &done](const Status& s) { c->status = s; @@ -490,24 +513,46 @@ class RunManyGraphs { TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs); }; -Status MasterSession::ReffedClientGraph::RunPartitions( - const MasterEnv* env, int64 step_id, int64 execution_count, - PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req, - MutableRunStepResponseWrapper* resp, CancellationManager* cm, - const bool is_last_partial_run) { - VLOG(2) << "RunPartitions step_id " << step_id << " execution_count " - << execution_count; - // Maps the names of fed tensors to their index in `req`. - std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3); +namespace { +Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req, + MutableRunGraphRequestWrapper* worker_req, + size_t index, const string& send_key) { + return worker_req->AddSendFromRunStepRequest(client_req, index, send_key); +} - for (size_t i = 0; i < req.num_feeds(); ++i) { - if (!feeds.insert({req.feed_name(i), i}).second) { - return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i)); - } - } +Status AddSendFromClientRequest(const RunCallableRequest& client_req, + MutableRunGraphRequestWrapper* worker_req, + size_t index, const string& send_key) { + return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key); +} - // Prepares a number of calls to workers. One call per partition. +// TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for +// in-process messages. +struct RunCallableResponseWrapper { + RunCallableResponse* resp; // Not owned. + std::unordered_map<string, TensorProto> fetch_key_to_protos; + + RunMetadata* mutable_metadata() { return resp->mutable_metadata(); } + Status AddTensorFromRunGraphResponse( + const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp, + size_t index) { + // TODO(b/74355905): Add a specialized implementation that avoids + // copying the tensor into the RunCallableResponse when at least + // two of the {client, master, worker} are in the same process. + return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]); + } +}; +} // namespace + +template <class FetchListType, class ClientRequestType, + class ClientResponseType> +Status MasterSession::ReffedClientGraph::RunPartitionsHelper( + const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds, + const FetchListType& fetches, const MasterEnv* env, int64 step_id, + int64 execution_count, PerStepState* pss, CallOptions* call_opts, + const ClientRequestType& req, ClientResponseType* resp, + CancellationManager* cm, bool is_last_partial_run) { // Collect execution cost stats on a smoothly decreasing frequency. ExecutorOpts exec_opts; if (pss->report_tensor_allocations_upon_oom) { @@ -553,28 +598,19 @@ Status MasterSession::ReffedClientGraph::RunPartitions( // We keep these as separate paths for now, to ensure we aren't // inadvertently slowing down the normal run path. if (is_partial_) { - for (size_t i = 0; i < req.num_feeds(); ++i) { - const string& name = req.feed_name(i); - const auto iter = part.feed_key.find(name); + for (const auto& name_index : feeds) { + const auto iter = part.feed_key.find(name_index.first.ToString()); if (iter == part.feed_key.end()) { // The provided feed must be for a different partition. continue; } const string& key = iter->second; - auto feeds_iter = feeds.find(name); - if (feeds_iter == feeds.end()) { - return errors::InvalidArgument("No feed is provided for feed=", name, - ", key=", key); - } else if (feeds_iter->second != static_cast<size_t>(i)) { - return errors::Internal("Cannot find feed named \"", name, - " in request."); - } - TF_RETURN_IF_ERROR(c->req->AddSendFromRunStepRequest(req, i, key)); + TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(), + name_index.second, key)); } // TODO(suharshs): Make a map from feed to fetch_key to make this faster. // For now, we just iterate through partitions to find the matching key. - for (int i = 0; static_cast<size_t>(i) < req.num_fetches(); ++i) { - const string& req_fetch = req.fetch_name(i); + for (const string& req_fetch : fetches) { for (const auto& key_fetch : part.key_fetch) { if (key_fetch.second == req_fetch) { c->req->add_recv_key(key_fetch.first); @@ -586,9 +622,13 @@ Status MasterSession::ReffedClientGraph::RunPartitions( for (const auto& feed_key : part.feed_key) { const string& feed = feed_key.first; const string& key = feed_key.second; - const int64 feed_index = feeds[feed]; + auto iter = feeds.find(feed); + if (iter == feeds.end()) { + return errors::Internal("No feed index found for feed: ", feed); + } + const int64 feed_index = iter->second; TF_RETURN_IF_ERROR( - c->req->AddSendFromRunStepRequest(req, feed_index, key)); + AddSendFromClientRequest(req, c->req.get(), feed_index, key)); } for (const auto& key_fetch : part.key_fetch) { const string& key = key_fetch.first; @@ -622,50 +662,115 @@ Status MasterSession::ReffedClientGraph::RunPartitions( } else { return errors::Cancelled("Step was cancelled"); } + TF_RETURN_IF_ERROR(calls.status()); - // Collects fetches. - Status status = calls.status(); - if (status.ok()) { - for (int i = 0; i < num; ++i) { - const Part& part = partitions_[i]; - MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get(); - for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) { - auto iter = part.key_fetch.find(run_graph_resp->recv_key(j)); - if (iter == part.key_fetch.end()) { - status.Update(errors::Internal("Unexpected fetch key: ", - run_graph_resp->recv_key(j))); - break; - } - const string& fetch = iter->second; - status.Update( - resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j)); - if (!status.ok()) { - break; - } + // Collects fetches and metadata. + Status status; + for (int i = 0; i < num; ++i) { + const Part& part = partitions_[i]; + MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get(); + for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) { + auto iter = part.key_fetch.find(run_graph_resp->recv_key(j)); + if (iter == part.key_fetch.end()) { + status.Update(errors::Internal("Unexpected fetch key: ", + run_graph_resp->recv_key(j))); + break; } - if (pss->collect_timeline) { - pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats()); + const string& fetch = iter->second; + status.Update( + resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j)); + if (!status.ok()) { + break; } - if (pss->collect_costs) { - CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph(); - for (int j = 0; j < cost_graph->node_size(); ++j) { - resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap( - cost_graph->mutable_node(j)); - } + } + if (pss->collect_timeline) { + pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats()); + } + if (pss->collect_costs) { + CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph(); + for (int j = 0; j < cost_graph->node_size(); ++j) { + resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap( + cost_graph->mutable_node(j)); } - if (pss->collect_partition_graphs) { - protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs = - resp->mutable_metadata()->mutable_partition_graphs(); - for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) { - partition_graph_defs->Add()->Swap( - run_graph_resp->mutable_partition_graph(i)); - } + } + if (pss->collect_partition_graphs) { + protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs = + resp->mutable_metadata()->mutable_partition_graphs(); + for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) { + partition_graph_defs->Add()->Swap( + run_graph_resp->mutable_partition_graph(i)); } } } return status; } +Status MasterSession::ReffedClientGraph::RunPartitions( + const MasterEnv* env, int64 step_id, int64 execution_count, + PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req, + MutableRunStepResponseWrapper* resp, CancellationManager* cm, + const bool is_last_partial_run) { + VLOG(2) << "RunPartitions step_id " << step_id << " execution_count " + << execution_count; + // Maps the names of fed tensors to their index in `req`. + std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3); + for (size_t i = 0; i < req.num_feeds(); ++i) { + if (!feeds.insert({req.feed_name(i), i}).second) { + return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i)); + } + } + + std::vector<string> fetches; + fetches.reserve(req.num_fetches()); + for (size_t i = 0; i < req.num_fetches(); ++i) { + fetches.push_back(req.fetch_name(i)); + } + + return RunPartitionsHelper(feeds, fetches, env, step_id, execution_count, pss, + call_opts, req, resp, cm, is_last_partial_run); +} + +Status MasterSession::ReffedClientGraph::RunPartitions( + const MasterEnv* env, int64 step_id, int64 execution_count, + PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req, + RunCallableResponse* resp, CancellationManager* cm) { + VLOG(2) << "RunPartitions step_id " << step_id << " execution_count " + << execution_count; + // Maps the names of fed tensors to their index in `req`. + std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3); + for (size_t i = 0; i < callable_opts_.feed_size(); ++i) { + if (!feeds.insert({callable_opts_.feed(i), i}).second) { + // MakeCallable will fail if there are two feeds with the same name. + return errors::Internal("Duplicated feeds in callable: ", + callable_opts_.feed(i)); + } + } + + // Create a wrapped response object to collect the fetched values and + // rearrange them for the RunCallableResponse. + RunCallableResponseWrapper wrapped_resp; + wrapped_resp.resp = resp; + + TF_RETURN_IF_ERROR(RunPartitionsHelper( + feeds, callable_opts_.fetch(), env, step_id, execution_count, pss, + call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */)); + + // Collects fetches. + // TODO(b/74355905): Add a specialized implementation that avoids + // copying the tensor into the RunCallableResponse when at least + // two of the {client, master, worker} are in the same process. + for (const string& fetch : callable_opts_.fetch()) { + TensorProto* fetch_proto = resp->mutable_fetch()->Add(); + auto iter = wrapped_resp.fetch_key_to_protos.find(fetch); + if (iter == wrapped_resp.fetch_key_to_protos.end()) { + return errors::Internal("Worker did not return a value for fetch: ", + fetch); + } + fetch_proto->Swap(&iter->second); + } + return Status::OK(); +} + namespace { class CleanupBroadcastHelper { @@ -1266,15 +1371,11 @@ WorkerCacheInterface* MasterSession::get_worker_cache() const { return env_->worker_cache; } -Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, - ReffedClientGraph** rcg, bool is_partial) { +Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial, + ReffedClientGraph** out_rcg, int64* out_count) { const uint64 hash = HashBuildGraphOptions(opts); { mutex_lock l(mu_); - // Keep track of how many times this subgraph has been executed in - // this session. - int64* c = &subgraph_execution_counts_[hash]; - *count = (*c)++; // TODO(suharshs): We cache partial run graphs and run graphs separately // because there is preprocessing that needs to only be run for partial // run calls. @@ -1296,8 +1397,9 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, iter = m->insert({hash, entry}).first; VLOG(1) << "Preparing to execute new graph"; } - *rcg = iter->second; - (*rcg)->Ref(); + *out_rcg = iter->second; + (*out_rcg)->Ref(); + *out_count = (*out_rcg)->get_and_increment_execution_count(); } return Status::OK(); } @@ -1316,6 +1418,12 @@ void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref, rcg_map->clear(); } +namespace { +uint64 MakeStepId() { + return (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56); +} +} // namespace + Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req, PartialRunSetupResponse* resp) { std::vector<string> inputs, outputs, targets; @@ -1332,15 +1440,15 @@ Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req, string handle = std::to_string(partial_run_handle_counter_.fetch_add(1)); ReffedClientGraph* rcg = nullptr; - int64 count = 0; // Prepare. BuildGraphOptions opts; BuildBuildGraphOptions(*req, &opts); - TF_RETURN_IF_ERROR(StartStep(opts, &count, &rcg, true)); + int64 count; + TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count)); // Keeps the highest 8 bits 0x01: we reserve some bits of the // step_id for future use. - uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56); + const uint64 step_id = MakeStepId(); TRACEPRINTF("stepid %llu", step_id); rcg->Ref(); @@ -1585,6 +1693,73 @@ Status MasterSession::CreateDebuggerState( return Status::OK(); } +void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg, + const RunOptions& run_options, + uint64 step_id, int64 count, + PerStepState* out_pss, + std::unique_ptr<ProfileHandler>* out_ph) { + out_pss->collect_timeline = + run_options.trace_level() == RunOptions::FULL_TRACE; + out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE; + out_pss->report_tensor_allocations_upon_oom = + run_options.report_tensor_allocations_upon_oom(); + // Build the cost model every 'build_cost_model_every' steps after skipping an + // initial 'build_cost_model_after' steps. + const int64 build_cost_model_after = + session_opts_.config.graph_options().build_cost_model_after(); + const int64 build_cost_model_every = + session_opts_.config.graph_options().build_cost_model(); + out_pss->collect_costs = + build_cost_model_every > 0 && + ((count + 1 - build_cost_model_after) % build_cost_model_every == 0); + out_pss->collect_partition_graphs = run_options.output_partition_graphs(); + + *out_ph = rcg->GetProfileHandler(step_id, count, run_options); + if (*out_ph) { + out_pss->collect_timeline = true; + out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs(); + } +} + +Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg, + uint64 step_id, + const RunOptions& run_options, + PerStepState* pss, + const std::unique_ptr<ProfileHandler>& ph, + const Status& run_status, + RunMetadata* out_run_metadata) { + Status s = run_status; + if (s.ok()) { + pss->end_micros = Env::Default()->NowMicros(); + + // Schedule post-processing and cleanup to be done asynchronously. + rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata); + } else if (errors::IsCancelled(s)) { + mutex_lock l(mu_); + if (closed_) { + if (garbage_collected_) { + s = errors::Cancelled( + "Step was cancelled because the session was garbage collected due " + "to inactivity."); + } else { + s = errors::Cancelled( + "Step was cancelled by an explicit call to `Session::Close()`."); + } + } + } + Ref(); + rcg->Ref(); + rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) { + if (!s.ok()) { + LOG(ERROR) << "Cleanup partition error: " << s; + } + rcg->Unref(); + MarkRunCompletion(); + Unref(); + }); + return s; +} + Status MasterSession::DoRunWithLocalExecution( CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp) { @@ -1597,8 +1772,8 @@ Status MasterSession::DoRunWithLocalExecution( BuildGraphOptions bgopts; BuildBuildGraphOptions(req, &bgopts); ReffedClientGraph* rcg = nullptr; - int64 count = 0; - TF_RETURN_IF_ERROR(StartStep(bgopts, &count, &rcg, false)); + int64 count; + TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count)); // Unref "rcg" when out of scope. core::ScopedUnref unref(rcg); @@ -1614,64 +1789,133 @@ Status MasterSession::DoRunWithLocalExecution( // Keeps the highest 8 bits 0x01: we reserve some bits of the // step_id for future use. - const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56); + const uint64 step_id = MakeStepId(); TRACEPRINTF("stepid %llu", step_id); - pss.collect_timeline = req.options().trace_level() == RunOptions::FULL_TRACE; - pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE; - pss.report_tensor_allocations_upon_oom = - req.options().report_tensor_allocations_upon_oom(); - // Build the cost model every 'build_cost_model_every' steps after skipping an - // initial 'build_cost_model_after' steps. - const int64 build_cost_model_after = - session_opts_.config.graph_options().build_cost_model_after(); - const int64 build_cost_model_every = - session_opts_.config.graph_options().build_cost_model(); - pss.collect_costs = - build_cost_model_every > 0 && - ((count + 1 - build_cost_model_after) % build_cost_model_every == 0); - pss.collect_partition_graphs = req.options().output_partition_graphs(); + std::unique_ptr<ProfileHandler> ph; + FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph); - std::unique_ptr<ProfileHandler> ph = - rcg->GetProfileHandler(step_id, count, req.options()); - if (ph) { - pss.collect_timeline = true; - pss.collect_rpcs = ph->should_collect_rpcs(); + Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp, + &cancellation_manager_, false); + cleanup.release(); // MarkRunCompletion called in PostRunCleanup(). + return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s, + resp->mutable_metadata()); +} + +Status MasterSession::MakeCallable(const MakeCallableRequest& req, + MakeCallableResponse* resp) { + UpdateLastAccessTime(); + + BuildGraphOptions opts; + opts.callable_options = req.options(); + opts.use_function_convention = false; + + ReffedClientGraph* callable; + + { + mutex_lock l(mu_); + if (closed_) { + return errors::FailedPrecondition("Session is closed."); + } + std::unique_ptr<ClientGraph> client_graph; + TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph)); + callable = new ReffedClientGraph(handle_, opts, std::move(client_graph), + session_opts_, stats_publisher_factory_, + false /* is_partial */, get_worker_cache(), + !should_delete_worker_sessions_); + } + + Status s = BuildAndRegisterPartitions(callable); + if (!s.ok()) { + callable->Unref(); + return s; } + uint64 handle; + { + mutex_lock l(mu_); + handle = next_callable_handle_++; + callables_[handle] = callable; + } + + resp->set_handle(handle); + return Status::OK(); +} + +Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg, + const RunCallableRequest& req, + RunCallableResponse* resp) { + VLOG(2) << "DoRunCallable req: " << req.DebugString(); + PerStepState pss; + pss.start_micros = Env::Default()->NowMicros(); + auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); }); + + // Prepare. + int64 count = rcg->get_and_increment_execution_count(); + + // Keeps the highest 8 bits 0x01: we reserve some bits of the + // step_id for future use. + const uint64 step_id = MakeStepId(); + TRACEPRINTF("stepid %llu", step_id); + + const RunOptions& run_options = rcg->callable_options().run_options(); + + if (run_options.timeout_in_ms() != 0) { + opts->SetTimeout(run_options.timeout_in_ms()); + } + + std::unique_ptr<ProfileHandler> ph; + FillPerStepState(rcg, run_options, step_id, count, &pss, &ph); Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp, - &cancellation_manager_, false); - if (s.ok()) { - pss.end_micros = Env::Default()->NowMicros(); + &cancellation_manager_); + cleanup.release(); // MarkRunCompletion called in PostRunCleanup(). + return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s, + resp->mutable_metadata()); +} - // Schedule post-processing and cleanup to be done asynchronously. - rcg->ProcessStats(step_id, &pss, ph.get(), req.options(), - resp->mutable_metadata()); - } else if (errors::IsCancelled(s)) { +Status MasterSession::RunCallable(CallOptions* opts, + const RunCallableRequest& req, + RunCallableResponse* resp) { + UpdateLastAccessTime(); + ReffedClientGraph* callable; + { mutex_lock l(mu_); if (closed_) { - if (garbage_collected_) { - s = errors::Cancelled( - "Step was cancelled because the session was garbage collected due " - "to inactivity."); - } else { - s = errors::Cancelled( - "Step was cancelled by an explicit call to `Session::Close()`."); - } + return errors::FailedPrecondition("Session is closed."); + } + int64 handle = req.handle(); + if (handle >= next_callable_handle_) { + return errors::InvalidArgument("No such callable handle: ", handle); + } + auto iter = callables_.find(req.handle()); + if (iter == callables_.end()) { + return errors::InvalidArgument( + "Attempted to run callable after handle was released: ", handle); } + callable = iter->second; + callable->Ref(); + ++num_running_; } - Ref(); - rcg->Ref(); - cleanup.release(); // MarkRunCompletion called in done closure. - rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) { - if (!s.ok()) { - LOG(ERROR) << "Cleanup partition error: " << s; + core::ScopedUnref unref_callable(callable); + return DoRunCallable(opts, callable, req, resp); +} + +Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req, + ReleaseCallableResponse* resp) { + UpdateLastAccessTime(); + ReffedClientGraph* to_unref = nullptr; + { + mutex_lock l(mu_); + auto iter = callables_.find(req.handle()); + if (iter != callables_.end()) { + to_unref = iter->second; + callables_.erase(iter); } - rcg->Unref(); - MarkRunCompletion(); - Unref(); - }); - return s; + } + if (to_unref != nullptr) { + to_unref->Unref(); + } + return Status::OK(); } Status MasterSession::Close() { @@ -1688,6 +1932,7 @@ Status MasterSession::Close() { } ClearRunsTable(&to_unref, &run_graphs_); ClearRunsTable(&to_unref, &partial_run_graphs_); + ClearRunsTable(&to_unref, &callables_); } for (ReffedClientGraph* rcg : to_unref) rcg->Unref(); if (should_delete_worker_sessions_) { diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index 4bd4e1367a..a05419904f 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -89,6 +89,15 @@ class MasterSession : public core::RefCounted { Status ListDevices(ListDevicesResponse* resp) const; + Status MakeCallable(const MakeCallableRequest& req, + MakeCallableResponse* resp); + + Status RunCallable(CallOptions* opts, const RunCallableRequest& req, + RunCallableResponse* resp); + + Status ReleaseCallable(const ReleaseCallableRequest& req, + ReleaseCallableResponse* resp); + // Close this session and delete "*this". Returns OK if all known // states are cleanup successfully. // @@ -140,6 +149,8 @@ class MasterSession : public core::RefCounted { typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap; RCGMap run_graphs_ GUARDED_BY(mu_); RCGMap partial_run_graphs_ GUARDED_BY(mu_); + int64 next_callable_handle_ GUARDED_BY(mu_) = 0; + RCGMap callables_ GUARDED_BY(mu_); struct PerStepState { bool collect_costs = false; @@ -205,15 +216,28 @@ class MasterSession : public core::RefCounted { bool should_delete_worker_sessions_ = false; Status DeleteWorkerSessions(); - Status StartStep(const BuildGraphOptions& opts, int64* count, - ReffedClientGraph** graph, bool is_partial); + Status StartStep(const BuildGraphOptions& opts, bool is_partial, + ReffedClientGraph** out_rcg, int64* out_count); void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref, RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_); + void FillPerStepState(MasterSession::ReffedClientGraph* rcg, + const RunOptions& run_options, uint64 step_id, + int64 count, PerStepState* out_pss, + std::unique_ptr<ProfileHandler>* out_ph); Status DoRunWithLocalExecution(CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp); Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req, MutableRunStepResponseWrapper* resp); + Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg, + const RunCallableRequest& req, + RunCallableResponse* resp); + Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, uint64 step_id, + const RunOptions& run_options, PerStepState* pss, + const std::unique_ptr<ProfileHandler>& ph, + const Status& run_status, + RunMetadata* out_run_metadata); + void MarkRunCompletion(); void UpdateLastAccessTime(); diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc index 66ebb3080a..18668b44d3 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.cc +++ b/tensorflow/core/distributed_runtime/message_wrappers.cc @@ -326,6 +326,20 @@ Status InMemoryRunGraphRequest::AddSendFromRunStepRequest( return Status::OK(); } +// TODO(b/74355905): Add a specialized implementation that avoids +// copying the tensor when at least two of the {client, master, +// worker} are in the same process. +Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest( + const RunCallableRequest& run_callable_request, size_t i, + const string& send_key) { + Tensor tensor; + if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) { + return errors::InvalidArgument("Invalid TensorProto for feed value ", i); + } + sends_.emplace_back(send_key, std::move(tensor)); + return Status::OK(); +} + size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); } const string& InMemoryRunGraphRequest::recv_key(size_t i) const { @@ -439,6 +453,18 @@ Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest( return Status::OK(); } +// TODO(b/74355905): Add a specialized implementation that avoids +// copying the tensor when at least two of the {client, master, +// worker} are in the same process. +Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest( + const RunCallableRequest& run_callable_request, size_t i, + const string& send_key) { + NamedTensorProto* send = request_.add_send(); + send->set_name(send_key); + *send->mutable_tensor() = run_callable_request.feed(i); + return Status::OK(); +} + size_t MutableProtoRunGraphRequest::num_recvs() const { return request_.recv_key_size(); } diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h index 79fa6f926e..1f7cdb98a4 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.h +++ b/tensorflow/core/distributed_runtime/message_wrappers.h @@ -302,6 +302,9 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper { virtual Status AddSendFromRunStepRequest( const RunStepRequestWrapper& run_step_request, size_t i, const string& send_key) = 0; + virtual Status AddSendFromRunCallableRequest( + const RunCallableRequest& run_callable_request, size_t i, + const string& send_key) = 0; virtual void add_recv_key(const string& recv_key) = 0; virtual void set_is_partial(bool is_partial) = 0; @@ -334,6 +337,9 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper { Status AddSendFromRunStepRequest( const RunStepRequestWrapper& run_step_request, size_t i, const string& send_key) override; + Status AddSendFromRunCallableRequest( + const RunCallableRequest& run_callable_request, size_t i, + const string& send_key) override; void add_recv_key(const string& recv_key) override; void set_is_partial(bool is_partial) override; void set_is_last_partial_run(bool is_last_partial_run) override; @@ -385,6 +391,9 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper { Status AddSendFromRunStepRequest( const RunStepRequestWrapper& run_step_request, size_t i, const string& send_key) override; + Status AddSendFromRunCallableRequest( + const RunCallableRequest& run_callable_request, size_t i, + const string& send_key) override; void add_recv_key(const string& recv_key) override; void set_is_partial(bool is_partial) override; void set_is_last_partial_run(bool is_last_partial_run) override; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc index 63745e8ebd..23968e24c8 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc @@ -111,6 +111,11 @@ class GrpcMasterService : public AsyncServiceInterface { ENQUEUE_REQUEST(CloseSession, false); ENQUEUE_REQUEST(ListDevices, false); ENQUEUE_REQUEST(Reset, false); + ENQUEUE_REQUEST(MakeCallable, false); + for (int i = 0; i < 100; ++i) { + ENQUEUE_REQUEST(RunCallable, true); + } + ENQUEUE_REQUEST(ReleaseCallable, false); void* tag; bool ok; @@ -236,6 +241,47 @@ class GrpcMasterService : public AsyncServiceInterface { }); ENQUEUE_REQUEST(Reset, false); } + + // RPC handler for making a callable. + void MakeCallableHandler( + MasterCall<MakeCallableRequest, MakeCallableResponse>* call) { + master_impl_->MakeCallable(&call->request, &call->response, + [call](const Status& status) { + call->SendResponse(ToGrpcStatus(status)); + }); + ENQUEUE_REQUEST(MakeCallable, false); + } + + // RPC handler for running a callable. + void RunCallableHandler( + MasterCall<RunCallableRequest, RunCallableResponse>* call) { + auto* trace = TraceRpc("RunCallable/Server", call->client_metadata()); + CallOptions* call_opts = new CallOptions; + // The timeout may be overridden by a non-zero timeout in the + // callable's `RunOptions`; this overriding will happen inside the + // `MasterSession` implementation. + call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms()); + call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); }); + master_impl_->RunCallable(call_opts, &call->request, &call->response, + [call, call_opts, trace](const Status& status) { + call->ClearCancelCallback(); + delete call_opts; + delete trace; + call->SendResponse(ToGrpcStatus(status)); + }); + ENQUEUE_REQUEST(RunCallable, false); + } + + // RPC handler for making a callable. + void ReleaseCallableHandler( + MasterCall<ReleaseCallableRequest, ReleaseCallableResponse>* call) { + master_impl_->ReleaseCallable(&call->request, &call->response, + [call](const Status& status) { + call->SendResponse(ToGrpcStatus(status)); + }); + ENQUEUE_REQUEST(ReleaseCallable, false); + } + #undef ENQUEUE_REQUEST // Start tracing, including the ID attached to the RPC. diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc index e2016e824c..c832adbbbf 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc @@ -36,6 +36,9 @@ static const char* grpcMasterService_method_names[] = { "/tensorflow.MasterService/CloseSession", "/tensorflow.MasterService/ListDevices", "/tensorflow.MasterService/Reset", + "/tensorflow.MasterService/MakeCallable", + "/tensorflow.MasterService/RunCallable", + "/tensorflow.MasterService/ReleaseCallable", }; std::unique_ptr<MasterService::Stub> MasterService::NewStub( @@ -64,7 +67,14 @@ MasterService::Stub::Stub( rpcmethod_ListDevices_(grpcMasterService_method_names[5], ::grpc::internal::RpcMethod::NORMAL_RPC, channel), rpcmethod_Reset_(grpcMasterService_method_names[6], - ::grpc::internal::RpcMethod::NORMAL_RPC, channel) {} + ::grpc::internal::RpcMethod::NORMAL_RPC, channel), + rpcmethod_MakeCallable_(grpcMasterService_method_names[7], + ::grpc::internal::RpcMethod::NORMAL_RPC, channel), + rpcmethod_RunCallable_(grpcMasterService_method_names[8], + ::grpc::internal::RpcMethod::NORMAL_RPC, channel), + rpcmethod_ReleaseCallable_(grpcMasterService_method_names[9], + ::grpc::internal::RpcMethod::NORMAL_RPC, + channel) {} ::grpc::Status MasterService::Stub::CreateSession( ::grpc::ClientContext* context, const CreateSessionRequest& request, @@ -115,8 +125,29 @@ MasterService::Stub::Stub( context, request, response); } +::grpc::Status MasterService::Stub::MakeCallable( + ::grpc::ClientContext* context, const MakeCallableRequest& request, + MakeCallableResponse* response) { + return ::grpc::internal::BlockingUnaryCall( + channel_.get(), rpcmethod_MakeCallable_, context, request, response); +} + +::grpc::Status MasterService::Stub::RunCallable( + ::grpc::ClientContext* context, const RunCallableRequest& request, + RunCallableResponse* response) { + return ::grpc::internal::BlockingUnaryCall( + channel_.get(), rpcmethod_RunCallable_, context, request, response); +} + +::grpc::Status MasterService::Stub::ReleaseCallable( + ::grpc::ClientContext* context, const ReleaseCallableRequest& request, + ReleaseCallableResponse* response) { + return ::grpc::internal::BlockingUnaryCall( + channel_.get(), rpcmethod_ReleaseCallable_, context, request, response); +} + MasterService::AsyncService::AsyncService() { - for (int i = 0; i < 7; ++i) { + for (int i = 0; i < 10; ++i) { AddMethod(new ::grpc::internal::RpcServiceMethod( grpcMasterService_method_names[i], ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr)); diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h index 6ae94b7441..3c382738c4 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h @@ -79,6 +79,15 @@ class MasterService final { virtual ::grpc::Status Reset(::grpc::ClientContext* context, const ResetRequest& request, ResetResponse* response) = 0; + virtual ::grpc::Status MakeCallable(::grpc::ClientContext* context, + const MakeCallableRequest& request, + MakeCallableResponse* response) = 0; + virtual ::grpc::Status RunCallable(::grpc::ClientContext* context, + const RunCallableRequest& request, + RunCallableResponse* response) = 0; + virtual ::grpc::Status ReleaseCallable( + ::grpc::ClientContext* context, const ReleaseCallableRequest& request, + ReleaseCallableResponse* response) = 0; }; class Stub final : public StubInterface { public: @@ -104,6 +113,15 @@ class MasterService final { ::grpc::Status Reset(::grpc::ClientContext* context, const ResetRequest& request, ResetResponse* response) override; + ::grpc::Status MakeCallable(::grpc::ClientContext* context, + const MakeCallableRequest& request, + MakeCallableResponse* response) override; + ::grpc::Status RunCallable(::grpc::ClientContext* context, + const RunCallableRequest& request, + RunCallableResponse* response) override; + ::grpc::Status ReleaseCallable(::grpc::ClientContext* context, + const ReleaseCallableRequest& request, + ReleaseCallableResponse* response) override; private: std::shared_ptr< ::grpc::ChannelInterface> channel_; @@ -114,6 +132,9 @@ class MasterService final { const ::grpc::internal::RpcMethod rpcmethod_CloseSession_; const ::grpc::internal::RpcMethod rpcmethod_ListDevices_; const ::grpc::internal::RpcMethod rpcmethod_Reset_; + const ::grpc::internal::RpcMethod rpcmethod_MakeCallable_; + const ::grpc::internal::RpcMethod rpcmethod_RunCallable_; + const ::grpc::internal::RpcMethod rpcmethod_ReleaseCallable_; }; static std::unique_ptr<Stub> NewStub( const std::shared_ptr< ::grpc::ChannelInterface>& channel, @@ -179,6 +200,30 @@ class MasterService final { ::grpc::Service::RequestAsyncUnary(6, context, request, response, new_call_cq, notification_cq, tag); } + void RequestMakeCallable( + ::grpc::ServerContext* context, MakeCallableRequest* request, + ::grpc::ServerAsyncResponseWriter<MakeCallableResponse>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(7, context, request, response, + new_call_cq, notification_cq, tag); + } + void RequestRunCallable( + ::grpc::ServerContext* context, RunCallableRequest* request, + ::grpc::ServerAsyncResponseWriter<RunCallableResponse>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(8, context, request, response, + new_call_cq, notification_cq, tag); + } + void RequestReleaseCallable( + ::grpc::ServerContext* context, ReleaseCallableRequest* request, + ::grpc::ServerAsyncResponseWriter<ReleaseCallableResponse>* response, + ::grpc::CompletionQueue* new_call_cq, + ::grpc::ServerCompletionQueue* notification_cq, void* tag) { + ::grpc::Service::RequestAsyncUnary(9, context, request, response, + new_call_cq, notification_cq, tag); + } }; }; diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc index 1088e9be66..1b92a79a67 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc @@ -95,6 +95,28 @@ class GrpcRemoteMaster : public MasterInterface { &MasterServiceStub::Reset); } + Status MakeCallable(CallOptions* call_options, + const MakeCallableRequest* request, + MakeCallableResponse* response) override { + ::grpc::ClientContext ctx; + return Call(&ctx, call_options, request, response, + &MasterServiceStub::MakeCallable); + } + Status RunCallable(CallOptions* call_options, + const RunCallableRequest* request, + RunCallableResponse* response) override { + ::grpc::ClientContext ctx; + return Call(&ctx, call_options, request, response, + &MasterServiceStub::RunCallable); + } + Status ReleaseCallable(CallOptions* call_options, + const ReleaseCallableRequest* request, + ReleaseCallableResponse* response) override { + ::grpc::ClientContext ctx; + return Call(&ctx, call_options, request, response, + &MasterServiceStub::ReleaseCallable); + } + private: // Start tracing, attaching a unique ID to both the trace and the RPC. port::Tracing::TraceMe TraceRpc(StringPiece name, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc index 3e79a40683..fd1c150fa7 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc @@ -91,6 +91,15 @@ void ReEncodeConsts(GraphDef* gdef) { } } // namespace +Status GrpcSession::Handle(string* out_handle) { + mutex_lock l(mu_); + if (handle_.empty()) { + return errors::InvalidArgument("A session is not created yet...."); + } + *out_handle = handle_; + return Status::OK(); +} + Status GrpcSession::CreateImpl(CallOptions* call_options, const GraphDef& graph) { { @@ -274,14 +283,9 @@ Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs, Status GrpcSession::RunProto(CallOptions* call_options, MutableRunStepRequestWrapper* req, MutableRunStepResponseWrapper* resp) { - { - mutex_lock l(mu_); - if (handle_.empty()) { - return errors::InvalidArgument("A session is not created yet...."); - } - - req->set_session_handle(handle_); - } + string handle; + TF_RETURN_IF_ERROR(Handle(&handle)); + req->set_session_handle(handle); return master_->RunStep(call_options, req, resp); } @@ -293,14 +297,7 @@ Status GrpcSession::PRunSetup(const std::vector<string>& input_names, PartialRunSetupRequest req; PartialRunSetupResponse resp; CallOptions call_options; - { - mutex_lock l(mu_); - if (handle_.empty()) { - return errors::InvalidArgument("A session is not created yet...."); - } - - req.set_session_handle(handle_); - } + TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle())); for (const string& feed : input_names) { req.add_feed(feed); } @@ -400,6 +397,55 @@ Status GrpcSession::Reset(const SessionOptions& options, return ret; } +Status GrpcSession::MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) { + MakeCallableRequest req; + TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle())); + *req.mutable_options() = callable_options; + MakeCallableResponse resp; + CallOptions call_options; + call_options.SetTimeout(options_.config.operation_timeout_in_ms()); + TF_RETURN_IF_ERROR(master_->MakeCallable(&call_options, &req, &resp)); + *out_handle = resp.handle(); + return Status::OK(); +} + +Status GrpcSession::RunCallable(CallableHandle handle, + const std::vector<Tensor>& feed_tensors, + std::vector<Tensor>* fetch_tensors, + RunMetadata* run_metadata) { + RunCallableRequest req; + TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle())); + req.set_handle(handle); + for (const Tensor& feed : feed_tensors) { + feed.AsProtoTensorContent(req.mutable_feed()->Add()); + } + + RunCallableResponse resp; + CallOptions call_options; + call_options.SetTimeout(options_.config.operation_timeout_in_ms()); + TF_RETURN_IF_ERROR(master_->RunCallable(&call_options, &req, &resp)); + for (const TensorProto& fetch : resp.fetch()) { + Tensor fetch_tensor; + if (!fetch_tensor.FromProto(cpu_allocator(), fetch)) { + return errors::Internal( + "Could not parse fetched tensor data in response from master."); + } + fetch_tensors->push_back(std::move(fetch_tensor)); + } + return Status::OK(); +} + +Status GrpcSession::ReleaseCallable(CallableHandle handle) { + ReleaseCallableRequest req; + TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle())); + req.set_handle(handle); + ReleaseCallableResponse resp; + CallOptions call_options; + call_options.SetTimeout(options_.config.operation_timeout_in_ms()); + return master_->ReleaseCallable(&call_options, &req, &resp); +} + class GrpcSessionFactory : public SessionFactory { public: bool AcceptsOptions(const SessionOptions& options) override { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h index d87956a135..63795117f9 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h @@ -82,20 +82,27 @@ class GrpcSession : public Session { Status Close() override; // NOTE: This API is still experimental and may change. - ::tensorflow::Status PRunSetup(const std::vector<string>& input_names, - const std::vector<string>& output_names, - const std::vector<string>& target_nodes, - string* handle) override; + Status PRunSetup(const std::vector<string>& input_names, + const std::vector<string>& output_names, + const std::vector<string>& target_nodes, + string* handle) override; // NOTE: This API is still experimental and may change. - ::tensorflow::Status PRun( - const string& handle, - const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_names, - std::vector<Tensor>* outputs) override; + Status PRun(const string& handle, + const std::vector<std::pair<string, Tensor> >& inputs, + const std::vector<string>& output_names, + std::vector<Tensor>* outputs) override; Status ListDevices(std::vector<DeviceAttributes>* response) override; + Status MakeCallable(const CallableOptions& callable_options, + CallableHandle* out_handle) override; + Status RunCallable(CallableHandle handle, + const std::vector<Tensor>& feed_tensors, + std::vector<Tensor>* fetch_tensors, + RunMetadata* run_metadata) override; + Status ReleaseCallable(CallableHandle handle) override; + protected: // Takes ownership of `*master`. void SetRemoteMaster(std::unique_ptr<MasterInterface> master); @@ -111,6 +118,8 @@ class GrpcSession : public Session { // The current version of the graph. int64 current_graph_version_ GUARDED_BY(mu_); + Status Handle(string* out_handle) LOCKS_EXCLUDED(mu_); + Status RunHelper(const RunOptions& run_options, const std::vector<std::pair<string, Tensor> >& inputs, const std::vector<string>& output_tensor_names, diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc index 335c3febe2..45b15a54a2 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc @@ -120,6 +120,49 @@ TEST(GrpcSessionTest, BasicNonProtoAPI) { } } +TEST(GrpcSessionTest, BasicCallable) { + GraphDef graph; + string node_names[3]; + // c = a * b + CreateGraphDef(&graph, node_names); + + std::unique_ptr<test::TestCluster> cluster; + TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster)); + + std::unique_ptr<Session> session( + NewRemote(Options(cluster->targets()[0], 1))); + ASSERT_TRUE(session != nullptr); + + for (int iters = 0; iters < 25; ++iters) { + TF_CHECK_OK(session->Create(graph)); + { + // Just run to target node + CallableOptions opts; + opts.add_target(node_names[2]); + Session::CallableHandle handle; + TF_CHECK_OK(session->MakeCallable(opts, &handle)); + TF_CHECK_OK(session->RunCallable(handle, {}, nullptr, nullptr)); + TF_CHECK_OK(session->ReleaseCallable(handle)); + } + { + // Run to a target node and a real tensor + CallableOptions opts; + opts.add_target(node_names[1]); + opts.add_fetch(node_names[2] + ":0"); + Session::CallableHandle handle; + TF_CHECK_OK(session->MakeCallable(opts, &handle)); + std::vector<Tensor> outputs; + TF_CHECK_OK(session->RunCallable(handle, {}, &outputs, nullptr)); + ASSERT_EQ(1, outputs.size()); + ASSERT_TRUE(outputs[0].IsInitialized()); + ASSERT_EQ(4.0, outputs[0].flat<float>()(0)); + TF_CHECK_OK(session->ReleaseCallable(handle)); + } + + TF_CHECK_OK(session->Close()); + } +} + TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) { GraphDef graph; string node_names[3]; diff --git a/tensorflow/core/protobuf/master.proto b/tensorflow/core/protobuf/master.proto index 0437cb1b83..96c91536f7 100644 --- a/tensorflow/core/protobuf/master.proto +++ b/tensorflow/core/protobuf/master.proto @@ -23,6 +23,7 @@ option java_package = "org.tensorflow.distruntime"; import "tensorflow/core/framework/device_attributes.proto"; import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/framework/tensor.proto"; import "tensorflow/core/lib/core/error_codes.proto"; import "tensorflow/core/protobuf/config.proto"; import "tensorflow/core/protobuf/named_tensor.proto"; @@ -264,3 +265,70 @@ message ListDevicesResponse { repeated DeviceAttributes local_device = 1; repeated DeviceAttributes remote_device = 2; } + +//////////////////////////////////////////////////////////////////////////////// +// +// MakeCallable method request/response protos. +// +//////////////////////////////////////////////////////////////////////////////// + +message MakeCallableRequest { + // REQUIRED: session_handle must be returned by a CreateSession call + // to the same master service. + string session_handle = 1; + + // Options that define the behavior of the created callable. + CallableOptions options = 2; +} + +message MakeCallableResponse { + // A handle to the created callable. + int64 handle = 1; +} + +//////////////////////////////////////////////////////////////////////////////// +// +// RunCallable method request/response protos. +// +//////////////////////////////////////////////////////////////////////////////// + +message RunCallableRequest { + // REQUIRED: session_handle must be returned by a CreateSession call + // to the same master service. + string session_handle = 1; + // REQUIRED: handle must be returned by a MakeCallable call to the same + // master service. + int64 handle = 2; + + // Values of the tensors passed as arguments to the callable, in the order + // defined in the CallableOptions.feed field passed to MakeCallable. + repeated TensorProto feed = 3; +} + +message RunCallableResponse { + // Values of the tensors returned by the callable, in the order defined in the + // CallableOptions.fetch field passed to MakeCallable. + repeated TensorProto fetch = 1; + + // Returned metadata if requested in the options. + RunMetadata metadata = 2; +} + +//////////////////////////////////////////////////////////////////////////////// +// +// ReleaseCallable method request/response protos. +// +//////////////////////////////////////////////////////////////////////////////// + +message ReleaseCallableRequest { + // REQUIRED: session_handle must be returned by a CreateSession call + // to the same master service. + string session_handle = 1; + + // REQUIRED: handle must be returned by a MakeCallable call to the same + // master service. + int64 handle = 2; +} + +message ReleaseCallableResponse { +} diff --git a/tensorflow/core/protobuf/master_service.proto b/tensorflow/core/protobuf/master_service.proto index 771c80562a..1170611f37 100644 --- a/tensorflow/core/protobuf/master_service.proto +++ b/tensorflow/core/protobuf/master_service.proto @@ -107,4 +107,13 @@ service MasterService { // will no longer affect fresh ones via the resources in containers listed in // the ResetRequest. See ResetRequest for more details. rpc Reset(ResetRequest) returns (ResetResponse); + + // Registers a callable for execution with RunCallable. + rpc MakeCallable(MakeCallableRequest) returns (MakeCallableResponse); + + // Executes a callable registered with MakeCallable. + rpc RunCallable(RunCallableRequest) returns (RunCallableResponse); + + // Frees resources associated with a callable registered with MakeCallable. + rpc ReleaseCallable(ReleaseCallableRequest) returns (ReleaseCallableResponse); } |