diff options
author | 2017-07-19 01:17:42 -0700 | |
---|---|---|
committer | 2017-07-19 01:22:05 -0700 | |
commit | a555b786bb5ca046ffe2bbb2b1615ab24851e954 (patch) | |
tree | e17496deac385af433a90f2cda31b3b0fb645e59 | |
parent | 9b28e435c1507be52129ec22d9d006d12b1e79f3 (diff) |
Add output_partitions support (the `output_partition_graphs` run option, returned via `RunMetadata.partition_graphs`) in distributed runtime.
PiperOrigin-RevId: 162456565
-rw-r--r-- | tensorflow/core/distributed_runtime/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/graph_mgr.cc | 17 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/graph_mgr.h | 5 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/master_session.cc | 28 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/master_session.h | 1 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/message_wrappers.cc | 41 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/message_wrappers.h | 14 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/message_wrappers_test.cc | 15 | ||||
-rw-r--r-- | tensorflow/core/distributed_runtime/worker.cc | 5 | ||||
-rw-r--r-- | tensorflow/core/protobuf/worker.proto | 7 | ||||
-rw-r--r-- | tensorflow/python/client/session_test.py | 16 |
11 files changed, 134 insertions, 16 deletions
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 89a5f433bd..d9ed40c50a 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -359,6 +359,7 @@ cc_library( srcs = ["graph_mgr.cc"], hdrs = ["graph_mgr.h"], deps = [ + ":message_wrappers", ":rendezvous_mgr_interface", ":worker_env", "//tensorflow/core:core_cpu_internal", diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 205843d342..7f77bf8b4e 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -442,10 +442,9 @@ void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out, } void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, - WorkerSession* session, - const ExecutorOpts& /*opts*/, + WorkerSession* session, const ExecutorOpts& opts, StepStatsCollector* collector, - CostGraphDef* cost_graph, + MutableRunGraphResponseWrapper* response, CancellationManager* cancellation_manager, const NamedTensors& in, StatusCallback done) { // Lookup an item. Holds one ref while executing. 
@@ -464,6 +463,18 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, return; } + CostGraphDef* cost_graph = nullptr; + if (response != nullptr) { + cost_graph = response->mutable_cost_graph(); + if (opts.record_partition_graphs()) { + for (const ExecutionUnit& unit : item->units) { + GraphDef graph_def; + unit.graph->ToGraphDef(&graph_def); + response->AddPartitionGraph(graph_def); + } + } + } + RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); Status s = rendezvous->Initialize(session); diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index 4ee3711d02..fb83720d2a 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/costmodel_manager.h" #include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/distributed_runtime/message_wrappers.h" #include "tensorflow/core/distributed_runtime/worker_env.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/framework/cost_graph.pb.h" @@ -31,6 +32,7 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/protobuf/debug.pb.h" +#include "tensorflow/core/protobuf/worker.pb.h" namespace tensorflow { @@ -80,7 +82,8 @@ class GraphMgr { typedef std::function<void(const Status&)> StatusCallback; void ExecuteAsync(const string& handle, const int64 step_id, WorkerSession* session, const ExecutorOpts& opts, - StepStatsCollector* collector, CostGraphDef* cost_graph, + StepStatsCollector* collector, + MutableRunGraphResponseWrapper* response, CancellationManager* cancellation_manager, const NamedTensors& in, StatusCallback done); diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 2b6e0c5268..361e89290d 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -513,6 +513,9 @@ Status MasterSession::ReffedClientGraph::RunPartitions( if (pss->collect_rpcs) { SetRPCLogging(true); } + if (pss->collect_partition_graphs) { + exec_opts.set_record_partition_graphs(true); + } if (pss->collect_costs || pss->collect_timeline) { pss->step_stats.resize(partitions_.size()); } @@ -615,30 +618,39 @@ Status MasterSession::ReffedClientGraph::RunPartitions( if (status.ok()) { for (int i = 0; i < num; ++i) { const Part& part = partitions_[i]; - for (size_t j = 0; j < calls.get(i)->resp->num_recvs(); ++j) { - auto iter = part.key_fetch.find(calls.get(i)->resp->recv_key(j)); + MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get(); + for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) { + auto iter = part.key_fetch.find(run_graph_resp->recv_key(j)); if (iter == part.key_fetch.end()) { status.Update(errors::Internal("Unexpected fetch key: ", - calls.get(i)->resp->recv_key(j))); + run_graph_resp->recv_key(j))); break; } const string& fetch = iter->second; - status.Update(resp->AddTensorFromRunGraphResponse( - fetch, 
calls.get(i)->resp.get(), j)); + status.Update( + resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j)); if (!status.ok()) { break; } } if (pss->collect_timeline) { - pss->step_stats[i].Swap(calls.get(i)->resp->mutable_step_stats()); + pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats()); } if (pss->collect_costs) { - CostGraphDef* cost_graph = calls.get(i)->resp->mutable_cost_graph(); + CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph(); for (int j = 0; j < cost_graph->node_size(); ++j) { resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap( cost_graph->mutable_node(j)); } } + if (pss->collect_partition_graphs) { + protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs = + resp->mutable_metadata()->mutable_partition_graphs(); + for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) { + partition_graph_defs->Add()->Swap( + run_graph_resp->mutable_partition_graph(i)); + } + } } } return status; @@ -1361,6 +1373,7 @@ Status MasterSession::DoPartialRun(CallOptions* opts, pss.collect_costs = build_cost_model_every > 0 && ((count + 1 - build_cost_model_after) % build_cost_model_every == 0); + pss.collect_partition_graphs = req.options().output_partition_graphs(); std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler( run_state->step_id, count, req.options()); @@ -1517,6 +1530,7 @@ Status MasterSession::DoRunWithLocalExecution( pss.collect_costs = build_cost_model_every > 0 && ((count + 1 - build_cost_model_after) % build_cost_model_every == 0); + pss.collect_partition_graphs = req.options().output_partition_graphs(); std::unique_ptr<ProfileHandler> ph = rcg->GetProfileHandler(step_id, count, req.options()); diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h index 10fc4868ca..33b9bfe631 100644 --- a/tensorflow/core/distributed_runtime/master_session.h +++ b/tensorflow/core/distributed_runtime/master_session.h @@ -145,6 +145,7 
@@ class MasterSession : public core::RefCounted { bool collect_costs = false; bool collect_timeline = false; bool collect_rpcs = false; + bool collect_partition_graphs = false; Microseconds start_micros = Microseconds(0); Microseconds end_micros = Microseconds(0); std::vector<StepStats> step_stats; // per partition diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc index b5b564375d..a4a88e6e3b 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.cc +++ b/tensorflow/core/distributed_runtime/message_wrappers.cc @@ -523,6 +523,19 @@ RunGraphResponse* InMemoryRunGraphResponse::get_proto() { return nullptr; } +size_t InMemoryRunGraphResponse::num_partition_graphs() const { + return partition_graphs_.size(); +} + +GraphDef* InMemoryRunGraphResponse::mutable_partition_graph(size_t i) { + return &partition_graphs_[i]; +} + +void InMemoryRunGraphResponse::AddPartitionGraph( + const GraphDef& partition_graph) { + partition_graphs_.push_back(partition_graph); +} + size_t OwnedProtoRunGraphResponse::num_recvs() const { return response_.recv_size(); } @@ -563,6 +576,20 @@ CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() { RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; } +size_t OwnedProtoRunGraphResponse::num_partition_graphs() const { + return response_.partition_graph_size(); +} + +GraphDef* OwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) { + return response_.mutable_partition_graph(i); +} + +void OwnedProtoRunGraphResponse::AddPartitionGraph( + const GraphDef& partition_graph) { + GraphDef* graph_def = response_.mutable_partition_graph()->Add(); + *graph_def = partition_graph; +} + NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse( RunGraphResponse* response) : response_(response) {} @@ -609,6 +636,20 @@ RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() { return response_; } +size_t 
NonOwnedProtoRunGraphResponse::num_partition_graphs() const { + return response_->partition_graph_size(); +} + +GraphDef* NonOwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) { + return response_->mutable_partition_graph(i); +} + +void NonOwnedProtoRunGraphResponse::AddPartitionGraph( + const GraphDef& partition_graph) { + GraphDef* graph_def = response_->add_partition_graph(); + *graph_def = partition_graph; +} + MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {} size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); } diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h index f247b50dd5..0e3f5b98cb 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers.h +++ b/tensorflow/core/distributed_runtime/message_wrappers.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb_text.h" @@ -424,6 +425,9 @@ class MutableRunGraphResponseWrapper { // execution, if necessary. 
virtual StepStats* mutable_step_stats() = 0; virtual CostGraphDef* mutable_cost_graph() = 0; + virtual size_t num_partition_graphs() const = 0; + virtual GraphDef* mutable_partition_graph(size_t i) = 0; + virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0; protected: // Returns a mutable protobuf message that represents the contents of @@ -451,6 +455,9 @@ class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper { void AddRecv(const string& key, const Tensor& value) override; StepStats* mutable_step_stats() override; CostGraphDef* mutable_cost_graph() override; + size_t num_partition_graphs() const override; + GraphDef* mutable_partition_graph(size_t i) override; + void AddPartitionGraph(const GraphDef& partition_graph) override; protected: // NOTE: This method is not implemented. See @@ -461,6 +468,7 @@ class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper { gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_; StepStats step_stats_; CostGraphDef cost_graph_; + std::vector<GraphDef> partition_graphs_; }; // Proto-based message wrapper for use on the client side of the RunGraph RPC. 
@@ -474,6 +482,9 @@ class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { void AddRecv(const string& key, const Tensor& value) override; StepStats* mutable_step_stats() override; CostGraphDef* mutable_cost_graph() override; + size_t num_partition_graphs() const override; + GraphDef* mutable_partition_graph(size_t i) override; + void AddPartitionGraph(const GraphDef& partition_graph) override; protected: RunGraphResponse* get_proto() override; @@ -495,6 +506,9 @@ class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper { void AddRecv(const string& key, const Tensor& value) override; StepStats* mutable_step_stats() override; CostGraphDef* mutable_cost_graph() override; + size_t num_partition_graphs() const override; + GraphDef* mutable_partition_graph(size_t i) override; + void AddPartitionGraph(const GraphDef& partition_graph) override; protected: RunGraphResponse* get_proto() override; diff --git a/tensorflow/core/distributed_runtime/message_wrappers_test.cc b/tensorflow/core/distributed_runtime/message_wrappers_test.cc index 5b0c2945b9..e87fff151e 100644 --- a/tensorflow/core/distributed_runtime/message_wrappers_test.cc +++ b/tensorflow/core/distributed_runtime/message_wrappers_test.cc @@ -88,6 +88,7 @@ static void CheckRunGraphRequest(const RunGraphRequestWrapper& request) { EXPECT_EQ(13, request.step_id()); EXPECT_FALSE(request.exec_opts().record_costs()); EXPECT_TRUE(request.exec_opts().record_timeline()); + EXPECT_FALSE(request.exec_opts().record_partition_graphs()); EXPECT_EQ(2, request.num_sends()); Tensor val; TF_EXPECT_OK(request.SendValue(0, &val)); @@ -105,6 +106,9 @@ static void BuildRunGraphResponse( run_graph_response->mutable_step_stats()->add_dev_stats()->set_device( "/cpu:0"); run_graph_response->mutable_cost_graph()->add_node()->set_name("cost_node"); + GraphDef graph_def; + graph_def.set_version(1234); + run_graph_response->AddPartitionGraph(graph_def); } static void 
CheckRunGraphResponse(MutableRunGraphResponseWrapper* response) { @@ -120,6 +124,9 @@ static void CheckRunGraphResponse(MutableRunGraphResponseWrapper* response) { EXPECT_EQ("/cpu:0", response->mutable_step_stats()->dev_stats(0).device()); EXPECT_EQ(1, response->mutable_cost_graph()->node_size()); EXPECT_EQ("cost_node", response->mutable_cost_graph()->node(0).name()); + EXPECT_EQ(1, response->num_partition_graphs()); + GraphDef graph_def; + EXPECT_EQ(1234, response->mutable_partition_graph(0)->version()); } static void BuildRunStepResponse( @@ -131,6 +138,12 @@ static void BuildRunStepResponse( "fetch_y:0", run_graph_response, 1)); *run_step_response->mutable_metadata()->mutable_step_stats() = *run_graph_response->mutable_step_stats(); + protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs = + run_step_response->mutable_metadata()->mutable_partition_graphs(); + for (size_t i = 0; i < run_graph_response->num_partition_graphs(); i++) { + partition_graph_defs->Add()->Swap( + run_graph_response->mutable_partition_graph(i)); + } } static void CheckRunStepResponse( @@ -145,6 +158,8 @@ static void CheckRunStepResponse( test::ExpectTensorEqual<int32>(TensorB(), val); EXPECT_EQ(1, response.metadata().step_stats().dev_stats_size()); EXPECT_EQ("/cpu:0", response.metadata().step_stats().dev_stats(0).device()); + EXPECT_EQ(1, response.metadata().partition_graphs_size()); + EXPECT_EQ(1234, response.metadata().partition_graphs(0).version()); } TEST(MessageWrappers, RunStepRequest_Basic) { diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc index 16e450abb0..34b53e965f 100644 --- a/tensorflow/core/distributed_runtime/worker.cc +++ b/tensorflow/core/distributed_runtime/worker.cc @@ -156,10 +156,9 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request, return; } } - CostGraphDef* cost_graph = response->mutable_cost_graph(); session->graph_mgr->ExecuteAsync( request->graph_handle(), step_id, session, 
request->exec_opts(), - collector, cost_graph, cm, in, + collector, response, cm, in, [this, step_id, response, session, cm, out, token, collector, opts, done](Status s) { if (s.ok()) { @@ -230,7 +229,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts, } session->graph_mgr->ExecuteAsync( graph_handle, step_id, session, request->exec_opts(), - nullptr /* collector */, nullptr /* cost_graph */, cm, in, + nullptr /* collector */, nullptr /* response */, cm, in, [this, token, step_id, cm](Status s) { { mutex_lock l(mu_); diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto index 9d4a417aa3..137f9bc216 100644 --- a/tensorflow/core/protobuf/worker.proto +++ b/tensorflow/core/protobuf/worker.proto @@ -168,6 +168,7 @@ message CleanupAllResponse { message ExecutorOpts { bool record_costs = 1; bool record_timeline = 3; + bool record_partition_graphs = 4; }; message RunGraphRequest { @@ -212,10 +213,12 @@ message RunGraphResponse { // `RunGraphRequest.recv_key`. repeated NamedTensorProto recv = 1; - // If the request asked for execution stats or cost graph, these are returned - // here. + // If the request asked for execution stats, the cost graph, or the partition + // graphs, these are returned here. + // TODO(suharshs): Package these in a RunMetadata instead. 
StepStats step_stats = 2; CostGraphDef cost_graph = 3; + repeated GraphDef partition_graph = 4; } //////////////////////////////////////////////////////////////////////////////// diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 61d411b6f9..0cec75cf99 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -1508,6 +1508,22 @@ class SessionTest(test_util.TensorFlowTestCase): else: self.assertFalse(run_metadata.HasField('cost_graph')) + def runTestOutputPartitionGraphs(self, sess): + run_options = config_pb2.RunOptions(output_partition_graphs=True) + a = constant_op.constant(1) + run_metadata = config_pb2.RunMetadata() + sess.run(a, options=run_options, run_metadata=run_metadata) + self.assertGreater(len(run_metadata.partition_graphs), 0) + sess.run(a, run_metadata=run_metadata) + self.assertEqual(len(run_metadata.partition_graphs), 0) + + def testOutputPartitionGraphsDirect(self): + self.runTestOutputPartitionGraphs(session.Session()) + + def testOutputPartitionGraphsDistributed(self): + server = server_lib.Server.create_local_server() + self.runTestOutputPartitionGraphs(session.Session(server.target)) + def testNonInteractiveSessionNesting(self): sess1 = session.Session() sess1_controller = sess1.as_default() |