aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-04-06 17:39:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-06 18:18:06 -0700
commit470cc0f75108e68965f89026399f7b3a7a08196b (patch)
treee4997f253781b176976f9baf9949d0a0e3751c8a
parent38d1ac1e4f5b2a6e88eee43d332292898e0afc41 (diff)
Add remote session support for the MakeCallable API.
PiperOrigin-RevId: 191964391
-rw-r--r--tensorflow/core/distributed_runtime/local_master.cc41
-rw-r--r--tensorflow/core/distributed_runtime/local_master.h10
-rw-r--r--tensorflow/core/distributed_runtime/master.cc51
-rw-r--r--tensorflow/core/distributed_runtime/master.h7
-rw-r--r--tensorflow/core/distributed_runtime/master_interface.h10
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc499
-rw-r--r--tensorflow/core/distributed_runtime/master_session.h28
-rw-r--r--tensorflow/core/distributed_runtime/message_wrappers.cc26
-rw-r--r--tensorflow/core/distributed_runtime/message_wrappers.h9
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc46
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc35
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h45
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc22
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.cc78
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session.h27
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc43
-rw-r--r--tensorflow/core/protobuf/master.proto68
-rw-r--r--tensorflow/core/protobuf/master_service.proto9
18 files changed, 898 insertions, 156 deletions
diff --git a/tensorflow/core/distributed_runtime/local_master.cc b/tensorflow/core/distributed_runtime/local_master.cc
index aaa4cfa734..76315462a7 100644
--- a/tensorflow/core/distributed_runtime/local_master.cc
+++ b/tensorflow/core/distributed_runtime/local_master.cc
@@ -157,6 +157,47 @@ Status LocalMaster::Reset(CallOptions* call_options,
return ret;
}
+Status LocalMaster::MakeCallable(CallOptions* call_options,
+ const MakeCallableRequest* request,
+ MakeCallableResponse* response) {
+ Notification n;
+ Status ret;
+ master_impl_->MakeCallable(request, response, [&n, &ret](const Status& s) {
+ ret.Update(s);
+ n.Notify();
+ });
+ TF_RETURN_IF_ERROR(
+ WaitForNotification(call_options, default_timeout_in_ms_, &n));
+ return ret;
+}
+Status LocalMaster::RunCallable(CallOptions* call_options,
+ const RunCallableRequest* request,
+ RunCallableResponse* response) {
+ Notification n;
+ Status ret;
+ master_impl_->RunCallable(call_options, request, response,
+ [&n, &ret](const Status& s) {
+ ret.Update(s);
+ n.Notify();
+ });
+ TF_RETURN_IF_ERROR(
+ WaitForNotification(call_options, default_timeout_in_ms_, &n));
+ return ret;
+}
+Status LocalMaster::ReleaseCallable(CallOptions* call_options,
+ const ReleaseCallableRequest* request,
+ ReleaseCallableResponse* response) {
+ Notification n;
+ Status ret;
+ master_impl_->ReleaseCallable(request, response, [&n, &ret](const Status& s) {
+ ret.Update(s);
+ n.Notify();
+ });
+ TF_RETURN_IF_ERROR(
+ WaitForNotification(call_options, default_timeout_in_ms_, &n));
+ return ret;
+}
+
namespace {
mutex* get_local_master_registry_lock() {
static mutex local_master_registry_lock(LINKER_INITIALIZED);
diff --git a/tensorflow/core/distributed_runtime/local_master.h b/tensorflow/core/distributed_runtime/local_master.h
index c20b40329a..cad6babad8 100644
--- a/tensorflow/core/distributed_runtime/local_master.h
+++ b/tensorflow/core/distributed_runtime/local_master.h
@@ -71,6 +71,16 @@ class LocalMaster : public MasterInterface {
Status Reset(CallOptions* call_options, const ResetRequest* request,
ResetResponse* response) override;
+ Status MakeCallable(CallOptions* call_options,
+ const MakeCallableRequest* request,
+ MakeCallableResponse* response) override;
+ Status RunCallable(CallOptions* call_options,
+ const RunCallableRequest* request,
+ RunCallableResponse* response) override;
+ Status ReleaseCallable(CallOptions* call_options,
+ const ReleaseCallableRequest* request,
+ ReleaseCallableResponse* response);
+
// Registers the mapping from the given `target` to the given `master`.
//
// WARNING: The `master` pointer remains owned by the caller. It is
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index 1a488303ac..f47502e844 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -611,4 +611,55 @@ void Master::Reset(const ResetRequest* req, ResetResponse* resp,
});
}
+void Master::MakeCallable(const MakeCallableRequest* req,
+ MakeCallableResponse* resp, MyClosure done) {
+ auto session = FindMasterSession(req->session_handle());
+ if (session == nullptr) {
+ done(errors::Aborted("Session ", req->session_handle(), " is not found."));
+ return;
+ }
+
+ SchedClosure(std::bind(
+ [this, session, req, resp](MyClosure done) {
+ Status s = session->MakeCallable(*req, resp);
+ session->Unref();
+ done(s);
+ },
+ std::move(done)));
+}
+
+void Master::RunCallable(CallOptions* opts, const RunCallableRequest* req,
+ RunCallableResponse* resp, MyClosure done) {
+ auto session = FindMasterSession(req->session_handle());
+ if (session == nullptr) {
+ done(errors::Aborted("Session ", req->session_handle(), " is not found."));
+ return;
+ }
+
+ SchedClosure(std::bind(
+ [this, session, opts, req, resp](MyClosure done) {
+ Status s = session->RunCallable(opts, *req, resp);
+ session->Unref();
+ done(s);
+ },
+ std::move(done)));
+}
+
+void Master::ReleaseCallable(const ReleaseCallableRequest* req,
+ ReleaseCallableResponse* resp, MyClosure done) {
+ auto session = FindMasterSession(req->session_handle());
+ if (session == nullptr) {
+ done(errors::Aborted("Session ", req->session_handle(), " is not found."));
+ return;
+ }
+
+ SchedClosure(std::bind(
+ [this, session, req, resp](MyClosure done) {
+ Status s = session->ReleaseCallable(*req, resp);
+ session->Unref();
+ done(s);
+ },
+ std::move(done)));
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/master.h b/tensorflow/core/distributed_runtime/master.h
index 678fc46bd7..dbb337fd48 100644
--- a/tensorflow/core/distributed_runtime/master.h
+++ b/tensorflow/core/distributed_runtime/master.h
@@ -61,6 +61,13 @@ class Master {
// See tensorflow::Reset() and the comment on ResetRequest.
void Reset(const ResetRequest* req, ResetResponse* resp, MyClosure done);
+ void MakeCallable(const MakeCallableRequest* req, MakeCallableResponse* resp,
+ MyClosure done);
+ void RunCallable(CallOptions* opts, const RunCallableRequest* req,
+ RunCallableResponse* resp, MyClosure done);
+ void ReleaseCallable(const ReleaseCallableRequest* req,
+ ReleaseCallableResponse* resp, MyClosure done);
+
private:
typedef Master ME;
diff --git a/tensorflow/core/distributed_runtime/master_interface.h b/tensorflow/core/distributed_runtime/master_interface.h
index bf6a2db3e2..a8ae3cba3c 100644
--- a/tensorflow/core/distributed_runtime/master_interface.h
+++ b/tensorflow/core/distributed_runtime/master_interface.h
@@ -89,6 +89,16 @@ class MasterInterface {
virtual Status Reset(CallOptions* call_options, const ResetRequest* request,
ResetResponse* response) = 0;
+ virtual Status MakeCallable(CallOptions* call_options,
+ const MakeCallableRequest* request,
+ MakeCallableResponse* response) = 0;
+ virtual Status RunCallable(CallOptions* call_options,
+ const RunCallableRequest* request,
+ RunCallableResponse* response) = 0;
+ virtual Status ReleaseCallable(CallOptions* call_options,
+ const ReleaseCallableRequest* request,
+ ReleaseCallableResponse* response) = 0;
+
protected:
// NOTE: This should only be called by implementations of this
// interface whose CreateRunStepResponse() method returns a
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 64adf35c5e..e0a5bb4c53 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -72,7 +72,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
client_graph_(std::move(cg)),
session_opts_(session_opts),
is_partial_(is_partial),
- debug_opts_(bopts.callable_options.run_options().debug_options()),
+ callable_opts_(bopts.callable_options),
worker_cache_(worker_cache),
should_deregister_(should_deregister) {
VLOG(1) << "Created ReffedClientGraph for node with "
@@ -94,12 +94,18 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
const ClientGraph* client_graph() { return client_graph_.get(); }
+ const CallableOptions& callable_options() { return callable_opts_; }
+
std::unique_ptr<ProfileHandler> GetProfileHandler(uint64 step,
int64 execution_count,
const RunOptions& ropts) {
return stats_publisher_->GetProfileHandler(step, execution_count, ropts);
}
+ int64 get_and_increment_execution_count() {
+ return execution_count_.fetch_add(1);
+ }
+
// Turn RPC logging on or off, both at the WorkerCache used by this
// master process, and at each remote worker in use for the current
// partitions.
@@ -178,6 +184,10 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp,
CancellationManager* cm, const bool is_last_partial_run);
+ Status RunPartitions(const MasterEnv* env, int64 step_id,
+ int64 execution_count, PerStepState* pss,
+ CallOptions* call_opts, const RunCallableRequest& req,
+ RunCallableResponse* resp, CancellationManager* cm);
// Calls workers to cleanup states for the step "step_id". Calls
// `done` when all cleanup RPCs have completed.
@@ -211,10 +221,11 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
const std::unique_ptr<ClientGraph> client_graph_;
const SessionOptions session_opts_;
const bool is_partial_;
- const DebugOptions& debug_opts_;
+ const CallableOptions callable_opts_;
WorkerCacheInterface* const worker_cache_; // Not owned.
std::unordered_map<StringPiece, Node*, StringPieceHasher> name_to_node_;
const bool should_deregister_;
+ std::atomic<int64> execution_count_ = {0};
// Graph partitioned into per-location subgraphs.
struct Part {
@@ -269,6 +280,17 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
const PartitionOptions& popts,
std::unordered_map<string, GraphDef> graph_partitions);
+ // Prepares a number of calls to workers. One call per partition.
+ // This is a generic method that handles Run, PartialRun, and RunCallable.
+ template <class FetchListType, class ClientRequestType,
+ class ClientResponseType>
+ Status RunPartitionsHelper(
+ const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
+ const FetchListType& fetches, const MasterEnv* env, int64 step_id,
+ int64 execution_count, PerStepState* pss, CallOptions* call_opts,
+ const ClientRequestType& req, ClientResponseType* resp,
+ CancellationManager* cm, bool is_last_partial_run);
+
// Deregisters the partitions on the workers. Called in the
// destructor and does not wait for the rpc completion.
void DeregisterPartitions();
@@ -411,7 +433,8 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
c->req.set_session_handle(session_handle_);
c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
- *c->req.mutable_debug_options() = debug_opts_;
+ *c->req.mutable_debug_options() =
+ callable_opts_.run_options().debug_options();
VLOG(2) << "Register " << c->req.graph_def().DebugString();
auto cb = [c, &done](const Status& s) {
c->status = s;
@@ -490,24 +513,46 @@ class RunManyGraphs {
TF_DISALLOW_COPY_AND_ASSIGN(RunManyGraphs);
};
-Status MasterSession::ReffedClientGraph::RunPartitions(
- const MasterEnv* env, int64 step_id, int64 execution_count,
- PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
- MutableRunStepResponseWrapper* resp, CancellationManager* cm,
- const bool is_last_partial_run) {
- VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
- << execution_count;
- // Maps the names of fed tensors to their index in `req`.
- std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
+namespace {
+Status AddSendFromClientRequest(const RunStepRequestWrapper& client_req,
+ MutableRunGraphRequestWrapper* worker_req,
+ size_t index, const string& send_key) {
+ return worker_req->AddSendFromRunStepRequest(client_req, index, send_key);
+}
- for (size_t i = 0; i < req.num_feeds(); ++i) {
- if (!feeds.insert({req.feed_name(i), i}).second) {
- return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
- }
- }
+Status AddSendFromClientRequest(const RunCallableRequest& client_req,
+ MutableRunGraphRequestWrapper* worker_req,
+ size_t index, const string& send_key) {
+ return worker_req->AddSendFromRunCallableRequest(client_req, index, send_key);
+}
- // Prepares a number of calls to workers. One call per partition.
+// TODO(mrry): Add a full-fledged wrapper that avoids TensorProto copies for
+// in-process messages.
+struct RunCallableResponseWrapper {
+ RunCallableResponse* resp; // Not owned.
+ std::unordered_map<string, TensorProto> fetch_key_to_protos;
+
+ RunMetadata* mutable_metadata() { return resp->mutable_metadata(); }
+ Status AddTensorFromRunGraphResponse(
+ const string& tensor_name, MutableRunGraphResponseWrapper* worker_resp,
+ size_t index) {
+ // TODO(b/74355905): Add a specialized implementation that avoids
+ // copying the tensor into the RunCallableResponse when at least
+ // two of the {client, master, worker} are in the same process.
+ return worker_resp->RecvValue(index, &fetch_key_to_protos[tensor_name]);
+ }
+};
+} // namespace
+
+template <class FetchListType, class ClientRequestType,
+ class ClientResponseType>
+Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
+ const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
+ const FetchListType& fetches, const MasterEnv* env, int64 step_id,
+ int64 execution_count, PerStepState* pss, CallOptions* call_opts,
+ const ClientRequestType& req, ClientResponseType* resp,
+ CancellationManager* cm, bool is_last_partial_run) {
// Collect execution cost stats on a smoothly decreasing frequency.
ExecutorOpts exec_opts;
if (pss->report_tensor_allocations_upon_oom) {
@@ -553,28 +598,19 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
// We keep these as separate paths for now, to ensure we aren't
// inadvertently slowing down the normal run path.
if (is_partial_) {
- for (size_t i = 0; i < req.num_feeds(); ++i) {
- const string& name = req.feed_name(i);
- const auto iter = part.feed_key.find(name);
+ for (const auto& name_index : feeds) {
+ const auto iter = part.feed_key.find(name_index.first.ToString());
if (iter == part.feed_key.end()) {
// The provided feed must be for a different partition.
continue;
}
const string& key = iter->second;
- auto feeds_iter = feeds.find(name);
- if (feeds_iter == feeds.end()) {
- return errors::InvalidArgument("No feed is provided for feed=", name,
- ", key=", key);
- } else if (feeds_iter->second != static_cast<size_t>(i)) {
- return errors::Internal("Cannot find feed named \"", name,
- " in request.");
- }
- TF_RETURN_IF_ERROR(c->req->AddSendFromRunStepRequest(req, i, key));
+ TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
+ name_index.second, key));
}
// TODO(suharshs): Make a map from feed to fetch_key to make this faster.
// For now, we just iterate through partitions to find the matching key.
- for (int i = 0; static_cast<size_t>(i) < req.num_fetches(); ++i) {
- const string& req_fetch = req.fetch_name(i);
+ for (const string& req_fetch : fetches) {
for (const auto& key_fetch : part.key_fetch) {
if (key_fetch.second == req_fetch) {
c->req->add_recv_key(key_fetch.first);
@@ -586,9 +622,13 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
for (const auto& feed_key : part.feed_key) {
const string& feed = feed_key.first;
const string& key = feed_key.second;
- const int64 feed_index = feeds[feed];
+ auto iter = feeds.find(feed);
+ if (iter == feeds.end()) {
+ return errors::Internal("No feed index found for feed: ", feed);
+ }
+ const int64 feed_index = iter->second;
TF_RETURN_IF_ERROR(
- c->req->AddSendFromRunStepRequest(req, feed_index, key));
+ AddSendFromClientRequest(req, c->req.get(), feed_index, key));
}
for (const auto& key_fetch : part.key_fetch) {
const string& key = key_fetch.first;
@@ -622,50 +662,115 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
} else {
return errors::Cancelled("Step was cancelled");
}
+ TF_RETURN_IF_ERROR(calls.status());
- // Collects fetches.
- Status status = calls.status();
- if (status.ok()) {
- for (int i = 0; i < num; ++i) {
- const Part& part = partitions_[i];
- MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
- for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
- auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
- if (iter == part.key_fetch.end()) {
- status.Update(errors::Internal("Unexpected fetch key: ",
- run_graph_resp->recv_key(j)));
- break;
- }
- const string& fetch = iter->second;
- status.Update(
- resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
- if (!status.ok()) {
- break;
- }
+ // Collects fetches and metadata.
+ Status status;
+ for (int i = 0; i < num; ++i) {
+ const Part& part = partitions_[i];
+ MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
+ for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
+ auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
+ if (iter == part.key_fetch.end()) {
+ status.Update(errors::Internal("Unexpected fetch key: ",
+ run_graph_resp->recv_key(j)));
+ break;
}
- if (pss->collect_timeline) {
- pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
+ const string& fetch = iter->second;
+ status.Update(
+ resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
+ if (!status.ok()) {
+ break;
}
- if (pss->collect_costs) {
- CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
- for (int j = 0; j < cost_graph->node_size(); ++j) {
- resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
- cost_graph->mutable_node(j));
- }
+ }
+ if (pss->collect_timeline) {
+ pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
+ }
+ if (pss->collect_costs) {
+ CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
+ for (int j = 0; j < cost_graph->node_size(); ++j) {
+ resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
+ cost_graph->mutable_node(j));
}
- if (pss->collect_partition_graphs) {
- protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
- resp->mutable_metadata()->mutable_partition_graphs();
- for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
- partition_graph_defs->Add()->Swap(
- run_graph_resp->mutable_partition_graph(i));
- }
+ }
+ if (pss->collect_partition_graphs) {
+ protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
+ resp->mutable_metadata()->mutable_partition_graphs();
+ for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
+ partition_graph_defs->Add()->Swap(
+ run_graph_resp->mutable_partition_graph(i));
}
}
}
return status;
}
+Status MasterSession::ReffedClientGraph::RunPartitions(
+ const MasterEnv* env, int64 step_id, int64 execution_count,
+ PerStepState* pss, CallOptions* call_opts, const RunStepRequestWrapper& req,
+ MutableRunStepResponseWrapper* resp, CancellationManager* cm,
+ const bool is_last_partial_run) {
+ VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
+ << execution_count;
+ // Maps the names of fed tensors to their index in `req`.
+ std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
+ for (size_t i = 0; i < req.num_feeds(); ++i) {
+ if (!feeds.insert({req.feed_name(i), i}).second) {
+ return errors::InvalidArgument("Duplicated feeds: ", req.feed_name(i));
+ }
+ }
+
+ std::vector<string> fetches;
+ fetches.reserve(req.num_fetches());
+ for (size_t i = 0; i < req.num_fetches(); ++i) {
+ fetches.push_back(req.fetch_name(i));
+ }
+
+ return RunPartitionsHelper(feeds, fetches, env, step_id, execution_count, pss,
+ call_opts, req, resp, cm, is_last_partial_run);
+}
+
+Status MasterSession::ReffedClientGraph::RunPartitions(
+ const MasterEnv* env, int64 step_id, int64 execution_count,
+ PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
+ RunCallableResponse* resp, CancellationManager* cm) {
+ VLOG(2) << "RunPartitions step_id " << step_id << " execution_count "
+ << execution_count;
+ // Maps the names of fed tensors to their index in `req`.
+ std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
+ for (size_t i = 0; i < callable_opts_.feed_size(); ++i) {
+ if (!feeds.insert({callable_opts_.feed(i), i}).second) {
+ // MakeCallable will fail if there are two feeds with the same name.
+ return errors::Internal("Duplicated feeds in callable: ",
+ callable_opts_.feed(i));
+ }
+ }
+
+ // Create a wrapped response object to collect the fetched values and
+ // rearrange them for the RunCallableResponse.
+ RunCallableResponseWrapper wrapped_resp;
+ wrapped_resp.resp = resp;
+
+ TF_RETURN_IF_ERROR(RunPartitionsHelper(
+ feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
+ call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));
+
+ // Collects fetches.
+ // TODO(b/74355905): Add a specialized implementation that avoids
+ // copying the tensor into the RunCallableResponse when at least
+ // two of the {client, master, worker} are in the same process.
+ for (const string& fetch : callable_opts_.fetch()) {
+ TensorProto* fetch_proto = resp->mutable_fetch()->Add();
+ auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
+ if (iter == wrapped_resp.fetch_key_to_protos.end()) {
+ return errors::Internal("Worker did not return a value for fetch: ",
+ fetch);
+ }
+ fetch_proto->Swap(&iter->second);
+ }
+ return Status::OK();
+}
+
namespace {
class CleanupBroadcastHelper {
@@ -1266,15 +1371,11 @@ WorkerCacheInterface* MasterSession::get_worker_cache() const {
return env_->worker_cache;
}
-Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
- ReffedClientGraph** rcg, bool is_partial) {
+Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
+ ReffedClientGraph** out_rcg, int64* out_count) {
const uint64 hash = HashBuildGraphOptions(opts);
{
mutex_lock l(mu_);
- // Keep track of how many times this subgraph has been executed in
- // this session.
- int64* c = &subgraph_execution_counts_[hash];
- *count = (*c)++;
// TODO(suharshs): We cache partial run graphs and run graphs separately
// because there is preprocessing that needs to only be run for partial
// run calls.
@@ -1296,8 +1397,9 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
iter = m->insert({hash, entry}).first;
VLOG(1) << "Preparing to execute new graph";
}
- *rcg = iter->second;
- (*rcg)->Ref();
+ *out_rcg = iter->second;
+ (*out_rcg)->Ref();
+ *out_count = (*out_rcg)->get_and_increment_execution_count();
}
return Status::OK();
}
@@ -1316,6 +1418,12 @@ void MasterSession::ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
rcg_map->clear();
}
+namespace {
+uint64 MakeStepId() {
+ return (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
+}
+} // namespace
+
Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
PartialRunSetupResponse* resp) {
std::vector<string> inputs, outputs, targets;
@@ -1332,15 +1440,15 @@ Status MasterSession::PartialRunSetup(const PartialRunSetupRequest* req,
string handle = std::to_string(partial_run_handle_counter_.fetch_add(1));
ReffedClientGraph* rcg = nullptr;
- int64 count = 0;
// Prepare.
BuildGraphOptions opts;
BuildBuildGraphOptions(*req, &opts);
- TF_RETURN_IF_ERROR(StartStep(opts, &count, &rcg, true));
+ int64 count;
+ TF_RETURN_IF_ERROR(StartStep(opts, true, &rcg, &count));
// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
- uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
+ const uint64 step_id = MakeStepId();
TRACEPRINTF("stepid %llu", step_id);
rcg->Ref();
@@ -1585,6 +1693,73 @@ Status MasterSession::CreateDebuggerState(
return Status::OK();
}
+void MasterSession::FillPerStepState(MasterSession::ReffedClientGraph* rcg,
+ const RunOptions& run_options,
+ uint64 step_id, int64 count,
+ PerStepState* out_pss,
+ std::unique_ptr<ProfileHandler>* out_ph) {
+ out_pss->collect_timeline =
+ run_options.trace_level() == RunOptions::FULL_TRACE;
+ out_pss->collect_rpcs = run_options.trace_level() == RunOptions::FULL_TRACE;
+ out_pss->report_tensor_allocations_upon_oom =
+ run_options.report_tensor_allocations_upon_oom();
+ // Build the cost model every 'build_cost_model_every' steps after skipping an
+ // initial 'build_cost_model_after' steps.
+ const int64 build_cost_model_after =
+ session_opts_.config.graph_options().build_cost_model_after();
+ const int64 build_cost_model_every =
+ session_opts_.config.graph_options().build_cost_model();
+ out_pss->collect_costs =
+ build_cost_model_every > 0 &&
+ ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
+ out_pss->collect_partition_graphs = run_options.output_partition_graphs();
+
+ *out_ph = rcg->GetProfileHandler(step_id, count, run_options);
+ if (*out_ph) {
+ out_pss->collect_timeline = true;
+ out_pss->collect_rpcs = (*out_ph)->should_collect_rpcs();
+ }
+}
+
+Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
+ uint64 step_id,
+ const RunOptions& run_options,
+ PerStepState* pss,
+ const std::unique_ptr<ProfileHandler>& ph,
+ const Status& run_status,
+ RunMetadata* out_run_metadata) {
+ Status s = run_status;
+ if (s.ok()) {
+ pss->end_micros = Env::Default()->NowMicros();
+
+ // Schedule post-processing and cleanup to be done asynchronously.
+ rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
+ } else if (errors::IsCancelled(s)) {
+ mutex_lock l(mu_);
+ if (closed_) {
+ if (garbage_collected_) {
+ s = errors::Cancelled(
+ "Step was cancelled because the session was garbage collected due "
+ "to inactivity.");
+ } else {
+ s = errors::Cancelled(
+ "Step was cancelled by an explicit call to `Session::Close()`.");
+ }
+ }
+ }
+ Ref();
+ rcg->Ref();
+ rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
+ if (!s.ok()) {
+ LOG(ERROR) << "Cleanup partition error: " << s;
+ }
+ rcg->Unref();
+ MarkRunCompletion();
+ Unref();
+ });
+ return s;
+}
+
Status MasterSession::DoRunWithLocalExecution(
CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp) {
@@ -1597,8 +1772,8 @@ Status MasterSession::DoRunWithLocalExecution(
BuildGraphOptions bgopts;
BuildBuildGraphOptions(req, &bgopts);
ReffedClientGraph* rcg = nullptr;
- int64 count = 0;
- TF_RETURN_IF_ERROR(StartStep(bgopts, &count, &rcg, false));
+ int64 count;
+ TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));
// Unref "rcg" when out of scope.
core::ScopedUnref unref(rcg);
@@ -1614,64 +1789,133 @@ Status MasterSession::DoRunWithLocalExecution(
// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
- const uint64 step_id = (random::New64() & ((1uLL << 56) - 1)) | (1uLL << 56);
+ const uint64 step_id = MakeStepId();
TRACEPRINTF("stepid %llu", step_id);
- pss.collect_timeline = req.options().trace_level() == RunOptions::FULL_TRACE;
- pss.collect_rpcs = req.options().trace_level() == RunOptions::FULL_TRACE;
- pss.report_tensor_allocations_upon_oom =
- req.options().report_tensor_allocations_upon_oom();
- // Build the cost model every 'build_cost_model_every' steps after skipping an
- // initial 'build_cost_model_after' steps.
- const int64 build_cost_model_after =
- session_opts_.config.graph_options().build_cost_model_after();
- const int64 build_cost_model_every =
- session_opts_.config.graph_options().build_cost_model();
- pss.collect_costs =
- build_cost_model_every > 0 &&
- ((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
- pss.collect_partition_graphs = req.options().output_partition_graphs();
+ std::unique_ptr<ProfileHandler> ph;
+ FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);
- std::unique_ptr<ProfileHandler> ph =
- rcg->GetProfileHandler(step_id, count, req.options());
- if (ph) {
- pss.collect_timeline = true;
- pss.collect_rpcs = ph->should_collect_rpcs();
+ Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
+ &cancellation_manager_, false);
+ cleanup.release(); // MarkRunCompletion called in PostRunCleanup().
+ return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
+ resp->mutable_metadata());
+}
+
+Status MasterSession::MakeCallable(const MakeCallableRequest& req,
+ MakeCallableResponse* resp) {
+ UpdateLastAccessTime();
+
+ BuildGraphOptions opts;
+ opts.callable_options = req.options();
+ opts.use_function_convention = false;
+
+ ReffedClientGraph* callable;
+
+ {
+ mutex_lock l(mu_);
+ if (closed_) {
+ return errors::FailedPrecondition("Session is closed.");
+ }
+ std::unique_ptr<ClientGraph> client_graph;
+ TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
+ callable = new ReffedClientGraph(handle_, opts, std::move(client_graph),
+ session_opts_, stats_publisher_factory_,
+ false /* is_partial */, get_worker_cache(),
+ !should_delete_worker_sessions_);
+ }
+
+ Status s = BuildAndRegisterPartitions(callable);
+ if (!s.ok()) {
+ callable->Unref();
+ return s;
}
+ uint64 handle;
+ {
+ mutex_lock l(mu_);
+ handle = next_callable_handle_++;
+ callables_[handle] = callable;
+ }
+
+ resp->set_handle(handle);
+ return Status::OK();
+}
+
+Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
+ const RunCallableRequest& req,
+ RunCallableResponse* resp) {
+ VLOG(2) << "DoRunCallable req: " << req.DebugString();
+ PerStepState pss;
+ pss.start_micros = Env::Default()->NowMicros();
+ auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
+
+ // Prepare.
+ int64 count = rcg->get_and_increment_execution_count();
+
+ // Keeps the highest 8 bits 0x01: we reserve some bits of the
+ // step_id for future use.
+ const uint64 step_id = MakeStepId();
+ TRACEPRINTF("stepid %llu", step_id);
+
+ const RunOptions& run_options = rcg->callable_options().run_options();
+
+ if (run_options.timeout_in_ms() != 0) {
+ opts->SetTimeout(run_options.timeout_in_ms());
+ }
+
+ std::unique_ptr<ProfileHandler> ph;
+ FillPerStepState(rcg, run_options, step_id, count, &pss, &ph);
Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
- &cancellation_manager_, false);
- if (s.ok()) {
- pss.end_micros = Env::Default()->NowMicros();
+ &cancellation_manager_);
+ cleanup.release(); // MarkRunCompletion called in PostRunCleanup().
+ return PostRunCleanup(rcg, step_id, run_options, &pss, ph, s,
+ resp->mutable_metadata());
+}
- // Schedule post-processing and cleanup to be done asynchronously.
- rcg->ProcessStats(step_id, &pss, ph.get(), req.options(),
- resp->mutable_metadata());
- } else if (errors::IsCancelled(s)) {
+Status MasterSession::RunCallable(CallOptions* opts,
+ const RunCallableRequest& req,
+ RunCallableResponse* resp) {
+ UpdateLastAccessTime();
+ ReffedClientGraph* callable;
+ {
mutex_lock l(mu_);
if (closed_) {
- if (garbage_collected_) {
- s = errors::Cancelled(
- "Step was cancelled because the session was garbage collected due "
- "to inactivity.");
- } else {
- s = errors::Cancelled(
- "Step was cancelled by an explicit call to `Session::Close()`.");
- }
+ return errors::FailedPrecondition("Session is closed.");
+ }
+ int64 handle = req.handle();
+ if (handle >= next_callable_handle_) {
+ return errors::InvalidArgument("No such callable handle: ", handle);
+ }
+ auto iter = callables_.find(req.handle());
+ if (iter == callables_.end()) {
+ return errors::InvalidArgument(
+ "Attempted to run callable after handle was released: ", handle);
}
+ callable = iter->second;
+ callable->Ref();
+ ++num_running_;
}
- Ref();
- rcg->Ref();
- cleanup.release(); // MarkRunCompletion called in done closure.
- rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
- if (!s.ok()) {
- LOG(ERROR) << "Cleanup partition error: " << s;
+ core::ScopedUnref unref_callable(callable);
+ return DoRunCallable(opts, callable, req, resp);
+}
+
+Status MasterSession::ReleaseCallable(const ReleaseCallableRequest& req,
+ ReleaseCallableResponse* resp) {
+ UpdateLastAccessTime();
+ ReffedClientGraph* to_unref = nullptr;
+ {
+ mutex_lock l(mu_);
+ auto iter = callables_.find(req.handle());
+ if (iter != callables_.end()) {
+ to_unref = iter->second;
+ callables_.erase(iter);
}
- rcg->Unref();
- MarkRunCompletion();
- Unref();
- });
- return s;
+ }
+ if (to_unref != nullptr) {
+ to_unref->Unref();
+ }
+ return Status::OK();
}
Status MasterSession::Close() {
@@ -1688,6 +1932,7 @@ Status MasterSession::Close() {
}
ClearRunsTable(&to_unref, &run_graphs_);
ClearRunsTable(&to_unref, &partial_run_graphs_);
+ ClearRunsTable(&to_unref, &callables_);
}
for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
if (should_delete_worker_sessions_) {
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index 4bd4e1367a..a05419904f 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -89,6 +89,15 @@ class MasterSession : public core::RefCounted {
Status ListDevices(ListDevicesResponse* resp) const;
+ Status MakeCallable(const MakeCallableRequest& req,
+ MakeCallableResponse* resp);
+
+ Status RunCallable(CallOptions* opts, const RunCallableRequest& req,
+ RunCallableResponse* resp);
+
+ Status ReleaseCallable(const ReleaseCallableRequest& req,
+ ReleaseCallableResponse* resp);
+
// Close this session and delete "*this". Returns OK if all known
// states are cleanup successfully.
//
@@ -140,6 +149,8 @@ class MasterSession : public core::RefCounted {
typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
RCGMap run_graphs_ GUARDED_BY(mu_);
RCGMap partial_run_graphs_ GUARDED_BY(mu_);
+ int64 next_callable_handle_ GUARDED_BY(mu_) = 0;
+ RCGMap callables_ GUARDED_BY(mu_);
struct PerStepState {
bool collect_costs = false;
@@ -205,15 +216,28 @@ class MasterSession : public core::RefCounted {
bool should_delete_worker_sessions_ = false;
Status DeleteWorkerSessions();
- Status StartStep(const BuildGraphOptions& opts, int64* count,
- ReffedClientGraph** graph, bool is_partial);
+ Status StartStep(const BuildGraphOptions& opts, bool is_partial,
+ ReffedClientGraph** out_rcg, int64* out_count);
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
RCGMap* rcg_map) EXCLUSIVE_LOCKS_REQUIRED(mu_);
+ void FillPerStepState(MasterSession::ReffedClientGraph* rcg,
+ const RunOptions& run_options, uint64 step_id,
+ int64 count, PerStepState* out_pss,
+ std::unique_ptr<ProfileHandler>* out_ph);
Status DoRunWithLocalExecution(CallOptions* opts,
const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp);
Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp);
+ Status DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
+ const RunCallableRequest& req,
+ RunCallableResponse* resp);
+ Status PostRunCleanup(MasterSession::ReffedClientGraph* rcg, uint64 step_id,
+ const RunOptions& run_options, PerStepState* pss,
+ const std::unique_ptr<ProfileHandler>& ph,
+ const Status& run_status,
+ RunMetadata* out_run_metadata);
+
void MarkRunCompletion();
void UpdateLastAccessTime();
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc
index 66ebb3080a..18668b44d3 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.cc
+++ b/tensorflow/core/distributed_runtime/message_wrappers.cc
@@ -326,6 +326,20 @@ Status InMemoryRunGraphRequest::AddSendFromRunStepRequest(
return Status::OK();
}
+// TODO(b/74355905): Add a specialized implementation that avoids
+// copying the tensor when at least two of the {client, master,
+// worker} are in the same process.
+Status InMemoryRunGraphRequest::AddSendFromRunCallableRequest(
+ const RunCallableRequest& run_callable_request, size_t i,
+ const string& send_key) {
+ Tensor tensor;
+ if (!ParseTensorProtoToTensor(run_callable_request.feed(i), &tensor)) {
+ return errors::InvalidArgument("Invalid TensorProto for feed value ", i);
+ }
+ sends_.emplace_back(send_key, std::move(tensor));
+ return Status::OK();
+}
+
size_t InMemoryRunGraphRequest::num_recvs() const { return recvs_.size(); }
const string& InMemoryRunGraphRequest::recv_key(size_t i) const {
@@ -439,6 +453,18 @@ Status MutableProtoRunGraphRequest::AddSendFromRunStepRequest(
return Status::OK();
}
+// TODO(b/74355905): Add a specialized implementation that avoids
+// copying the tensor when at least two of the {client, master,
+// worker} are in the same process.
+Status MutableProtoRunGraphRequest::AddSendFromRunCallableRequest(
+ const RunCallableRequest& run_callable_request, size_t i,
+ const string& send_key) {
+ NamedTensorProto* send = request_.add_send();
+ send->set_name(send_key);
+ *send->mutable_tensor() = run_callable_request.feed(i);
+ return Status::OK();
+}
+
size_t MutableProtoRunGraphRequest::num_recvs() const {
return request_.recv_key_size();
}
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h
index 79fa6f926e..1f7cdb98a4 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.h
+++ b/tensorflow/core/distributed_runtime/message_wrappers.h
@@ -302,6 +302,9 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
virtual Status AddSendFromRunStepRequest(
const RunStepRequestWrapper& run_step_request, size_t i,
const string& send_key) = 0;
+ virtual Status AddSendFromRunCallableRequest(
+ const RunCallableRequest& run_callable_request, size_t i,
+ const string& send_key) = 0;
virtual void add_recv_key(const string& recv_key) = 0;
virtual void set_is_partial(bool is_partial) = 0;
@@ -334,6 +337,9 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
Status AddSendFromRunStepRequest(
const RunStepRequestWrapper& run_step_request, size_t i,
const string& send_key) override;
+ Status AddSendFromRunCallableRequest(
+ const RunCallableRequest& run_callable_request, size_t i,
+ const string& send_key) override;
void add_recv_key(const string& recv_key) override;
void set_is_partial(bool is_partial) override;
void set_is_last_partial_run(bool is_last_partial_run) override;
@@ -385,6 +391,9 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
Status AddSendFromRunStepRequest(
const RunStepRequestWrapper& run_step_request, size_t i,
const string& send_key) override;
+ Status AddSendFromRunCallableRequest(
+ const RunCallableRequest& run_callable_request, size_t i,
+ const string& send_key) override;
void add_recv_key(const string& recv_key) override;
void set_is_partial(bool is_partial) override;
void set_is_last_partial_run(bool is_last_partial_run) override;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
index 63745e8ebd..23968e24c8 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc
@@ -111,6 +111,11 @@ class GrpcMasterService : public AsyncServiceInterface {
ENQUEUE_REQUEST(CloseSession, false);
ENQUEUE_REQUEST(ListDevices, false);
ENQUEUE_REQUEST(Reset, false);
+ ENQUEUE_REQUEST(MakeCallable, false);
+ for (int i = 0; i < 100; ++i) {
+ ENQUEUE_REQUEST(RunCallable, true);
+ }
+ ENQUEUE_REQUEST(ReleaseCallable, false);
void* tag;
bool ok;
@@ -236,6 +241,47 @@ class GrpcMasterService : public AsyncServiceInterface {
});
ENQUEUE_REQUEST(Reset, false);
}
+
+ // RPC handler for making a callable.
+ void MakeCallableHandler(
+ MasterCall<MakeCallableRequest, MakeCallableResponse>* call) {
+ master_impl_->MakeCallable(&call->request, &call->response,
+ [call](const Status& status) {
+ call->SendResponse(ToGrpcStatus(status));
+ });
+ ENQUEUE_REQUEST(MakeCallable, false);
+ }
+
+ // RPC handler for running a callable.
+ void RunCallableHandler(
+ MasterCall<RunCallableRequest, RunCallableResponse>* call) {
+ auto* trace = TraceRpc("RunCallable/Server", call->client_metadata());
+ CallOptions* call_opts = new CallOptions;
+ // The timeout may be overridden by a non-zero timeout in the
+ // callable's `RunOptions`; this overriding will happen inside the
+ // `MasterSession` implementation.
+ call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
+ call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
+ master_impl_->RunCallable(call_opts, &call->request, &call->response,
+ [call, call_opts, trace](const Status& status) {
+ call->ClearCancelCallback();
+ delete call_opts;
+ delete trace;
+ call->SendResponse(ToGrpcStatus(status));
+ });
+ ENQUEUE_REQUEST(RunCallable, false);
+ }
+
+ // RPC handler for making a callable.
+ void ReleaseCallableHandler(
+ MasterCall<ReleaseCallableRequest, ReleaseCallableResponse>* call) {
+ master_impl_->ReleaseCallable(&call->request, &call->response,
+ [call](const Status& status) {
+ call->SendResponse(ToGrpcStatus(status));
+ });
+ ENQUEUE_REQUEST(ReleaseCallable, false);
+ }
+
#undef ENQUEUE_REQUEST
// Start tracing, including the ID attached to the RPC.
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
index e2016e824c..c832adbbbf 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
@@ -36,6 +36,9 @@ static const char* grpcMasterService_method_names[] = {
"/tensorflow.MasterService/CloseSession",
"/tensorflow.MasterService/ListDevices",
"/tensorflow.MasterService/Reset",
+ "/tensorflow.MasterService/MakeCallable",
+ "/tensorflow.MasterService/RunCallable",
+ "/tensorflow.MasterService/ReleaseCallable",
};
std::unique_ptr<MasterService::Stub> MasterService::NewStub(
@@ -64,7 +67,14 @@ MasterService::Stub::Stub(
rpcmethod_ListDevices_(grpcMasterService_method_names[5],
::grpc::internal::RpcMethod::NORMAL_RPC, channel),
rpcmethod_Reset_(grpcMasterService_method_names[6],
- ::grpc::internal::RpcMethod::NORMAL_RPC, channel) {}
+ ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+ rpcmethod_MakeCallable_(grpcMasterService_method_names[7],
+ ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+ rpcmethod_RunCallable_(grpcMasterService_method_names[8],
+ ::grpc::internal::RpcMethod::NORMAL_RPC, channel),
+ rpcmethod_ReleaseCallable_(grpcMasterService_method_names[9],
+ ::grpc::internal::RpcMethod::NORMAL_RPC,
+ channel) {}
::grpc::Status MasterService::Stub::CreateSession(
::grpc::ClientContext* context, const CreateSessionRequest& request,
@@ -115,8 +125,29 @@ MasterService::Stub::Stub(
context, request, response);
}
+::grpc::Status MasterService::Stub::MakeCallable(
+ ::grpc::ClientContext* context, const MakeCallableRequest& request,
+ MakeCallableResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_MakeCallable_, context, request, response);
+}
+
+::grpc::Status MasterService::Stub::RunCallable(
+ ::grpc::ClientContext* context, const RunCallableRequest& request,
+ RunCallableResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_RunCallable_, context, request, response);
+}
+
+::grpc::Status MasterService::Stub::ReleaseCallable(
+ ::grpc::ClientContext* context, const ReleaseCallableRequest& request,
+ ReleaseCallableResponse* response) {
+ return ::grpc::internal::BlockingUnaryCall(
+ channel_.get(), rpcmethod_ReleaseCallable_, context, request, response);
+}
+
MasterService::AsyncService::AsyncService() {
- for (int i = 0; i < 7; ++i) {
+ for (int i = 0; i < 10; ++i) {
AddMethod(new ::grpc::internal::RpcServiceMethod(
grpcMasterService_method_names[i],
::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
index 6ae94b7441..3c382738c4 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
@@ -79,6 +79,15 @@ class MasterService final {
virtual ::grpc::Status Reset(::grpc::ClientContext* context,
const ResetRequest& request,
ResetResponse* response) = 0;
+ virtual ::grpc::Status MakeCallable(::grpc::ClientContext* context,
+ const MakeCallableRequest& request,
+ MakeCallableResponse* response) = 0;
+ virtual ::grpc::Status RunCallable(::grpc::ClientContext* context,
+ const RunCallableRequest& request,
+ RunCallableResponse* response) = 0;
+ virtual ::grpc::Status ReleaseCallable(
+ ::grpc::ClientContext* context, const ReleaseCallableRequest& request,
+ ReleaseCallableResponse* response) = 0;
};
class Stub final : public StubInterface {
public:
@@ -104,6 +113,15 @@ class MasterService final {
::grpc::Status Reset(::grpc::ClientContext* context,
const ResetRequest& request,
ResetResponse* response) override;
+ ::grpc::Status MakeCallable(::grpc::ClientContext* context,
+ const MakeCallableRequest& request,
+ MakeCallableResponse* response) override;
+ ::grpc::Status RunCallable(::grpc::ClientContext* context,
+ const RunCallableRequest& request,
+ RunCallableResponse* response) override;
+ ::grpc::Status ReleaseCallable(::grpc::ClientContext* context,
+ const ReleaseCallableRequest& request,
+ ReleaseCallableResponse* response) override;
private:
std::shared_ptr< ::grpc::ChannelInterface> channel_;
@@ -114,6 +132,9 @@ class MasterService final {
const ::grpc::internal::RpcMethod rpcmethod_CloseSession_;
const ::grpc::internal::RpcMethod rpcmethod_ListDevices_;
const ::grpc::internal::RpcMethod rpcmethod_Reset_;
+ const ::grpc::internal::RpcMethod rpcmethod_MakeCallable_;
+ const ::grpc::internal::RpcMethod rpcmethod_RunCallable_;
+ const ::grpc::internal::RpcMethod rpcmethod_ReleaseCallable_;
};
static std::unique_ptr<Stub> NewStub(
const std::shared_ptr< ::grpc::ChannelInterface>& channel,
@@ -179,6 +200,30 @@ class MasterService final {
::grpc::Service::RequestAsyncUnary(6, context, request, response,
new_call_cq, notification_cq, tag);
}
+ void RequestMakeCallable(
+ ::grpc::ServerContext* context, MakeCallableRequest* request,
+ ::grpc::ServerAsyncResponseWriter<MakeCallableResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(7, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ void RequestRunCallable(
+ ::grpc::ServerContext* context, RunCallableRequest* request,
+ ::grpc::ServerAsyncResponseWriter<RunCallableResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(8, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
+ void RequestReleaseCallable(
+ ::grpc::ServerContext* context, ReleaseCallableRequest* request,
+ ::grpc::ServerAsyncResponseWriter<ReleaseCallableResponse>* response,
+ ::grpc::CompletionQueue* new_call_cq,
+ ::grpc::ServerCompletionQueue* notification_cq, void* tag) {
+ ::grpc::Service::RequestAsyncUnary(9, context, request, response,
+ new_call_cq, notification_cq, tag);
+ }
};
};
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
index 1088e9be66..1b92a79a67 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc
@@ -95,6 +95,28 @@ class GrpcRemoteMaster : public MasterInterface {
&MasterServiceStub::Reset);
}
+ Status MakeCallable(CallOptions* call_options,
+ const MakeCallableRequest* request,
+ MakeCallableResponse* response) override {
+ ::grpc::ClientContext ctx;
+ return Call(&ctx, call_options, request, response,
+ &MasterServiceStub::MakeCallable);
+ }
+ Status RunCallable(CallOptions* call_options,
+ const RunCallableRequest* request,
+ RunCallableResponse* response) override {
+ ::grpc::ClientContext ctx;
+ return Call(&ctx, call_options, request, response,
+ &MasterServiceStub::RunCallable);
+ }
+ Status ReleaseCallable(CallOptions* call_options,
+ const ReleaseCallableRequest* request,
+ ReleaseCallableResponse* response) override {
+ ::grpc::ClientContext ctx;
+ return Call(&ctx, call_options, request, response,
+ &MasterServiceStub::ReleaseCallable);
+ }
+
private:
// Start tracing, attaching a unique ID to both the trace and the RPC.
port::Tracing::TraceMe TraceRpc(StringPiece name,
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
index 3e79a40683..fd1c150fa7 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.cc
@@ -91,6 +91,15 @@ void ReEncodeConsts(GraphDef* gdef) {
}
} // namespace
+Status GrpcSession::Handle(string* out_handle) {
+ mutex_lock l(mu_);
+ if (handle_.empty()) {
+ return errors::InvalidArgument("A session is not created yet....");
+ }
+ *out_handle = handle_;
+ return Status::OK();
+}
+
Status GrpcSession::CreateImpl(CallOptions* call_options,
const GraphDef& graph) {
{
@@ -274,14 +283,9 @@ Status GrpcSession::Run(const std::vector<std::pair<string, Tensor>>& inputs,
Status GrpcSession::RunProto(CallOptions* call_options,
MutableRunStepRequestWrapper* req,
MutableRunStepResponseWrapper* resp) {
- {
- mutex_lock l(mu_);
- if (handle_.empty()) {
- return errors::InvalidArgument("A session is not created yet....");
- }
-
- req->set_session_handle(handle_);
- }
+ string handle;
+ TF_RETURN_IF_ERROR(Handle(&handle));
+ req->set_session_handle(handle);
return master_->RunStep(call_options, req, resp);
}
@@ -293,14 +297,7 @@ Status GrpcSession::PRunSetup(const std::vector<string>& input_names,
PartialRunSetupRequest req;
PartialRunSetupResponse resp;
CallOptions call_options;
- {
- mutex_lock l(mu_);
- if (handle_.empty()) {
- return errors::InvalidArgument("A session is not created yet....");
- }
-
- req.set_session_handle(handle_);
- }
+ TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
for (const string& feed : input_names) {
req.add_feed(feed);
}
@@ -400,6 +397,55 @@ Status GrpcSession::Reset(const SessionOptions& options,
return ret;
}
+Status GrpcSession::MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) {
+ MakeCallableRequest req;
+ TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
+ *req.mutable_options() = callable_options;
+ MakeCallableResponse resp;
+ CallOptions call_options;
+ call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+ TF_RETURN_IF_ERROR(master_->MakeCallable(&call_options, &req, &resp));
+ *out_handle = resp.handle();
+ return Status::OK();
+}
+
+Status GrpcSession::RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) {
+ RunCallableRequest req;
+ TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
+ req.set_handle(handle);
+ for (const Tensor& feed : feed_tensors) {
+ feed.AsProtoTensorContent(req.mutable_feed()->Add());
+ }
+
+ RunCallableResponse resp;
+ CallOptions call_options;
+ call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+ TF_RETURN_IF_ERROR(master_->RunCallable(&call_options, &req, &resp));
+ for (const TensorProto& fetch : resp.fetch()) {
+ Tensor fetch_tensor;
+ if (!fetch_tensor.FromProto(cpu_allocator(), fetch)) {
+ return errors::Internal(
+ "Could not parse fetched tensor data in response from master.");
+ }
+ fetch_tensors->push_back(std::move(fetch_tensor));
+ }
+ return Status::OK();
+}
+
+Status GrpcSession::ReleaseCallable(CallableHandle handle) {
+ ReleaseCallableRequest req;
+ TF_RETURN_IF_ERROR(Handle(req.mutable_session_handle()));
+ req.set_handle(handle);
+ ReleaseCallableResponse resp;
+ CallOptions call_options;
+ call_options.SetTimeout(options_.config.operation_timeout_in_ms());
+ return master_->ReleaseCallable(&call_options, &req, &resp);
+}
+
class GrpcSessionFactory : public SessionFactory {
public:
bool AcceptsOptions(const SessionOptions& options) override {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session.h b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
index d87956a135..63795117f9 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session.h
@@ -82,20 +82,27 @@ class GrpcSession : public Session {
Status Close() override;
// NOTE: This API is still experimental and may change.
- ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
- const std::vector<string>& output_names,
- const std::vector<string>& target_nodes,
- string* handle) override;
+ Status PRunSetup(const std::vector<string>& input_names,
+ const std::vector<string>& output_names,
+ const std::vector<string>& target_nodes,
+ string* handle) override;
// NOTE: This API is still experimental and may change.
- ::tensorflow::Status PRun(
- const string& handle,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_names,
- std::vector<Tensor>* outputs) override;
+ Status PRun(const string& handle,
+ const std::vector<std::pair<string, Tensor> >& inputs,
+ const std::vector<string>& output_names,
+ std::vector<Tensor>* outputs) override;
Status ListDevices(std::vector<DeviceAttributes>* response) override;
+ Status MakeCallable(const CallableOptions& callable_options,
+ CallableHandle* out_handle) override;
+ Status RunCallable(CallableHandle handle,
+ const std::vector<Tensor>& feed_tensors,
+ std::vector<Tensor>* fetch_tensors,
+ RunMetadata* run_metadata) override;
+ Status ReleaseCallable(CallableHandle handle) override;
+
protected:
// Takes ownership of `*master`.
void SetRemoteMaster(std::unique_ptr<MasterInterface> master);
@@ -111,6 +118,8 @@ class GrpcSession : public Session {
// The current version of the graph.
int64 current_graph_version_ GUARDED_BY(mu_);
+ Status Handle(string* out_handle) LOCKS_EXCLUDED(mu_);
+
Status RunHelper(const RunOptions& run_options,
const std::vector<std::pair<string, Tensor> >& inputs,
const std::vector<string>& output_tensor_names,
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
index 335c3febe2..45b15a54a2 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_session_test.cc
@@ -120,6 +120,49 @@ TEST(GrpcSessionTest, BasicNonProtoAPI) {
}
}
+TEST(GrpcSessionTest, BasicCallable) {
+ GraphDef graph;
+ string node_names[3];
+ // c = a * b
+ CreateGraphDef(&graph, node_names);
+
+ std::unique_ptr<test::TestCluster> cluster;
+ TF_CHECK_OK(test::TestCluster::MakeTestCluster(Devices(1, 0), 2, &cluster));
+
+ std::unique_ptr<Session> session(
+ NewRemote(Options(cluster->targets()[0], 1)));
+ ASSERT_TRUE(session != nullptr);
+
+ for (int iters = 0; iters < 25; ++iters) {
+ TF_CHECK_OK(session->Create(graph));
+ {
+ // Just run to target node
+ CallableOptions opts;
+ opts.add_target(node_names[2]);
+ Session::CallableHandle handle;
+ TF_CHECK_OK(session->MakeCallable(opts, &handle));
+ TF_CHECK_OK(session->RunCallable(handle, {}, nullptr, nullptr));
+ TF_CHECK_OK(session->ReleaseCallable(handle));
+ }
+ {
+ // Run to a target node and a real tensor
+ CallableOptions opts;
+ opts.add_target(node_names[1]);
+ opts.add_fetch(node_names[2] + ":0");
+ Session::CallableHandle handle;
+ TF_CHECK_OK(session->MakeCallable(opts, &handle));
+ std::vector<Tensor> outputs;
+ TF_CHECK_OK(session->RunCallable(handle, {}, &outputs, nullptr));
+ ASSERT_EQ(1, outputs.size());
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ ASSERT_EQ(4.0, outputs[0].flat<float>()(0));
+ TF_CHECK_OK(session->ReleaseCallable(handle));
+ }
+
+ TF_CHECK_OK(session->Close());
+ }
+}
+
TEST(GrpcSessionTest, BasicNonProtoAPIConsistentOrder) {
GraphDef graph;
string node_names[3];
diff --git a/tensorflow/core/protobuf/master.proto b/tensorflow/core/protobuf/master.proto
index 0437cb1b83..96c91536f7 100644
--- a/tensorflow/core/protobuf/master.proto
+++ b/tensorflow/core/protobuf/master.proto
@@ -23,6 +23,7 @@ option java_package = "org.tensorflow.distruntime";
import "tensorflow/core/framework/device_attributes.proto";
import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/framework/tensor.proto";
import "tensorflow/core/lib/core/error_codes.proto";
import "tensorflow/core/protobuf/config.proto";
import "tensorflow/core/protobuf/named_tensor.proto";
@@ -264,3 +265,70 @@ message ListDevicesResponse {
repeated DeviceAttributes local_device = 1;
repeated DeviceAttributes remote_device = 2;
}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// MakeCallable method request/response protos.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message MakeCallableRequest {
+ // REQUIRED: session_handle must be returned by a CreateSession call
+ // to the same master service.
+ string session_handle = 1;
+
+ // Options that define the behavior of the created callable.
+ CallableOptions options = 2;
+}
+
+message MakeCallableResponse {
+ // A handle to the created callable.
+ int64 handle = 1;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// RunCallable method request/response protos.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message RunCallableRequest {
+ // REQUIRED: session_handle must be returned by a CreateSession call
+ // to the same master service.
+ string session_handle = 1;
+ // REQUIRED: handle must be returned by a MakeCallable call to the same
+ // master service.
+ int64 handle = 2;
+
+ // Values of the tensors passed as arguments to the callable, in the order
+ // defined in the CallableOptions.feed field passed to MakeCallable.
+ repeated TensorProto feed = 3;
+}
+
+message RunCallableResponse {
+ // Values of the tensors returned by the callable, in the order defined in the
+ // CallableOptions.fetch field passed to MakeCallable.
+ repeated TensorProto fetch = 1;
+
+ // Returned metadata if requested in the options.
+ RunMetadata metadata = 2;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// ReleaseCallable method request/response protos.
+//
+////////////////////////////////////////////////////////////////////////////////
+
+message ReleaseCallableRequest {
+ // REQUIRED: session_handle must be returned by a CreateSession call
+ // to the same master service.
+ string session_handle = 1;
+
+ // REQUIRED: handle must be returned by a MakeCallable call to the same
+ // master service.
+ int64 handle = 2;
+}
+
+message ReleaseCallableResponse {
+}
diff --git a/tensorflow/core/protobuf/master_service.proto b/tensorflow/core/protobuf/master_service.proto
index 771c80562a..1170611f37 100644
--- a/tensorflow/core/protobuf/master_service.proto
+++ b/tensorflow/core/protobuf/master_service.proto
@@ -107,4 +107,13 @@ service MasterService {
// will no longer affect fresh ones via the resources in containers listed in
// the ResetRequest. See ResetRequest for more details.
rpc Reset(ResetRequest) returns (ResetResponse);
+
+ // Registers a callable for execution with RunCallable.
+ rpc MakeCallable(MakeCallableRequest) returns (MakeCallableResponse);
+
+ // Executes a callable registered with MakeCallable.
+ rpc RunCallable(RunCallableRequest) returns (RunCallableResponse);
+
+ // Frees resources associated with a callable registered with MakeCallable.
+ rpc ReleaseCallable(ReleaseCallableRequest) returns (ReleaseCallableResponse);
}