about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author Suharsh Sivakumar <suharshs@google.com> 2017-07-19 01:17:42 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-07-19 01:22:05 -0700
commit a555b786bb5ca046ffe2bbb2b1615ab24851e954 (patch)
tree e17496deac385af433a90f2cda31b3b0fb645e59
parent 9b28e435c1507be52129ec22d9d006d12b1e79f3 (diff)
Add output_partitions support in distributed runtime.
PiperOrigin-RevId: 162456565
-rw-r--r-- tensorflow/core/distributed_runtime/BUILD | 1
-rw-r--r-- tensorflow/core/distributed_runtime/graph_mgr.cc | 17
-rw-r--r-- tensorflow/core/distributed_runtime/graph_mgr.h | 5
-rw-r--r-- tensorflow/core/distributed_runtime/master_session.cc | 28
-rw-r--r-- tensorflow/core/distributed_runtime/master_session.h | 1
-rw-r--r-- tensorflow/core/distributed_runtime/message_wrappers.cc | 41
-rw-r--r-- tensorflow/core/distributed_runtime/message_wrappers.h | 14
-rw-r--r-- tensorflow/core/distributed_runtime/message_wrappers_test.cc | 15
-rw-r--r-- tensorflow/core/distributed_runtime/worker.cc | 5
-rw-r--r-- tensorflow/core/protobuf/worker.proto | 7
-rw-r--r-- tensorflow/python/client/session_test.py | 16
11 files changed, 134 insertions, 16 deletions
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 89a5f433bd..d9ed40c50a 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -359,6 +359,7 @@ cc_library(
srcs = ["graph_mgr.cc"],
hdrs = ["graph_mgr.h"],
deps = [
+ ":message_wrappers",
":rendezvous_mgr_interface",
":worker_env",
"//tensorflow/core:core_cpu_internal",
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 205843d342..7f77bf8b4e 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -442,10 +442,9 @@ void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
}
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
- WorkerSession* session,
- const ExecutorOpts& /*opts*/,
+ WorkerSession* session, const ExecutorOpts& opts,
StepStatsCollector* collector,
- CostGraphDef* cost_graph,
+ MutableRunGraphResponseWrapper* response,
CancellationManager* cancellation_manager,
const NamedTensors& in, StatusCallback done) {
// Lookup an item. Holds one ref while executing.
@@ -464,6 +463,18 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
return;
}
+ CostGraphDef* cost_graph = nullptr;
+ if (response != nullptr) {
+ cost_graph = response->mutable_cost_graph();
+ if (opts.record_partition_graphs()) {
+ for (const ExecutionUnit& unit : item->units) {
+ GraphDef graph_def;
+ unit.graph->ToGraphDef(&graph_def);
+ response->AddPartitionGraph(graph_def);
+ }
+ }
+ }
+
RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = rendezvous->Initialize(session);
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h
index 4ee3711d02..fb83720d2a 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.h
+++ b/tensorflow/core/distributed_runtime/graph_mgr.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
@@ -31,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/debug.pb.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
namespace tensorflow {
@@ -80,7 +82,8 @@ class GraphMgr {
typedef std::function<void(const Status&)> StatusCallback;
void ExecuteAsync(const string& handle, const int64 step_id,
WorkerSession* session, const ExecutorOpts& opts,
- StepStatsCollector* collector, CostGraphDef* cost_graph,
+ StepStatsCollector* collector,
+ MutableRunGraphResponseWrapper* response,
CancellationManager* cancellation_manager,
const NamedTensors& in, StatusCallback done);
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 2b6e0c5268..361e89290d 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -513,6 +513,9 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
if (pss->collect_rpcs) {
SetRPCLogging(true);
}
+ if (pss->collect_partition_graphs) {
+ exec_opts.set_record_partition_graphs(true);
+ }
if (pss->collect_costs || pss->collect_timeline) {
pss->step_stats.resize(partitions_.size());
}
@@ -615,30 +618,39 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
if (status.ok()) {
for (int i = 0; i < num; ++i) {
const Part& part = partitions_[i];
- for (size_t j = 0; j < calls.get(i)->resp->num_recvs(); ++j) {
- auto iter = part.key_fetch.find(calls.get(i)->resp->recv_key(j));
+ MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
+ for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
+ auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
if (iter == part.key_fetch.end()) {
status.Update(errors::Internal("Unexpected fetch key: ",
- calls.get(i)->resp->recv_key(j)));
+ run_graph_resp->recv_key(j)));
break;
}
const string& fetch = iter->second;
- status.Update(resp->AddTensorFromRunGraphResponse(
- fetch, calls.get(i)->resp.get(), j));
+ status.Update(
+ resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
if (!status.ok()) {
break;
}
}
if (pss->collect_timeline) {
- pss->step_stats[i].Swap(calls.get(i)->resp->mutable_step_stats());
+ pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
}
if (pss->collect_costs) {
- CostGraphDef* cost_graph = calls.get(i)->resp->mutable_cost_graph();
+ CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
for (int j = 0; j < cost_graph->node_size(); ++j) {
resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
cost_graph->mutable_node(j));
}
}
+ if (pss->collect_partition_graphs) {
+ protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
+ resp->mutable_metadata()->mutable_partition_graphs();
+ for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
+ partition_graph_defs->Add()->Swap(
+ run_graph_resp->mutable_partition_graph(i));
+ }
+ }
}
}
return status;
@@ -1361,6 +1373,7 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
pss.collect_costs =
build_cost_model_every > 0 &&
((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
+ pss.collect_partition_graphs = req.options().output_partition_graphs();
std::unique_ptr<ProfileHandler> ph = run_state->rcg->GetProfileHandler(
run_state->step_id, count, req.options());
@@ -1517,6 +1530,7 @@ Status MasterSession::DoRunWithLocalExecution(
pss.collect_costs =
build_cost_model_every > 0 &&
((count + 1 - build_cost_model_after) % build_cost_model_every == 0);
+ pss.collect_partition_graphs = req.options().output_partition_graphs();
std::unique_ptr<ProfileHandler> ph =
rcg->GetProfileHandler(step_id, count, req.options());
diff --git a/tensorflow/core/distributed_runtime/master_session.h b/tensorflow/core/distributed_runtime/master_session.h
index 10fc4868ca..33b9bfe631 100644
--- a/tensorflow/core/distributed_runtime/master_session.h
+++ b/tensorflow/core/distributed_runtime/master_session.h
@@ -145,6 +145,7 @@ class MasterSession : public core::RefCounted {
bool collect_costs = false;
bool collect_timeline = false;
bool collect_rpcs = false;
+ bool collect_partition_graphs = false;
Microseconds start_micros = Microseconds(0);
Microseconds end_micros = Microseconds(0);
std::vector<StepStats> step_stats; // per partition
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.cc b/tensorflow/core/distributed_runtime/message_wrappers.cc
index b5b564375d..a4a88e6e3b 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.cc
+++ b/tensorflow/core/distributed_runtime/message_wrappers.cc
@@ -523,6 +523,19 @@ RunGraphResponse* InMemoryRunGraphResponse::get_proto() {
return nullptr;
}
+size_t InMemoryRunGraphResponse::num_partition_graphs() const {
+ return partition_graphs_.size();
+}
+
+GraphDef* InMemoryRunGraphResponse::mutable_partition_graph(size_t i) {
+ return &partition_graphs_[i];
+}
+
+void InMemoryRunGraphResponse::AddPartitionGraph(
+ const GraphDef& partition_graph) {
+ partition_graphs_.push_back(partition_graph);
+}
+
size_t OwnedProtoRunGraphResponse::num_recvs() const {
return response_.recv_size();
}
@@ -563,6 +576,20 @@ CostGraphDef* OwnedProtoRunGraphResponse::mutable_cost_graph() {
RunGraphResponse* OwnedProtoRunGraphResponse::get_proto() { return &response_; }
+size_t OwnedProtoRunGraphResponse::num_partition_graphs() const {
+ return response_.partition_graph_size();
+}
+
+GraphDef* OwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
+ return response_.mutable_partition_graph(i);
+}
+
+void OwnedProtoRunGraphResponse::AddPartitionGraph(
+ const GraphDef& partition_graph) {
+ GraphDef* graph_def = response_.mutable_partition_graph()->Add();
+ *graph_def = partition_graph;
+}
+
NonOwnedProtoRunGraphResponse::NonOwnedProtoRunGraphResponse(
RunGraphResponse* response)
: response_(response) {}
@@ -609,6 +636,20 @@ RunGraphResponse* NonOwnedProtoRunGraphResponse::get_proto() {
return response_;
}
+size_t NonOwnedProtoRunGraphResponse::num_partition_graphs() const {
+ return response_->partition_graph_size();
+}
+
+GraphDef* NonOwnedProtoRunGraphResponse::mutable_partition_graph(size_t i) {
+ return response_->mutable_partition_graph(i);
+}
+
+void NonOwnedProtoRunGraphResponse::AddPartitionGraph(
+ const GraphDef& partition_graph) {
+ GraphDef* graph_def = response_->add_partition_graph();
+ *graph_def = partition_graph;
+}
+
MutableRunStepResponseWrapper::~MutableRunStepResponseWrapper() {}
size_t InMemoryRunStepResponse::num_tensors() const { return tensors_.size(); }
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h
index f247b50dd5..0e3f5b98cb 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.h
+++ b/tensorflow/core/distributed_runtime/message_wrappers.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb_text.h"
@@ -424,6 +425,9 @@ class MutableRunGraphResponseWrapper {
// execution, if necessary.
virtual StepStats* mutable_step_stats() = 0;
virtual CostGraphDef* mutable_cost_graph() = 0;
+ virtual size_t num_partition_graphs() const = 0;
+ virtual GraphDef* mutable_partition_graph(size_t i) = 0;
+ virtual void AddPartitionGraph(const GraphDef& partition_graph) = 0;
protected:
// Returns a mutable protobuf message that represents the contents of
@@ -451,6 +455,9 @@ class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper {
void AddRecv(const string& key, const Tensor& value) override;
StepStats* mutable_step_stats() override;
CostGraphDef* mutable_cost_graph() override;
+ size_t num_partition_graphs() const override;
+ GraphDef* mutable_partition_graph(size_t i) override;
+ void AddPartitionGraph(const GraphDef& partition_graph) override;
protected:
// NOTE: This method is not implemented. See
@@ -461,6 +468,7 @@ class InMemoryRunGraphResponse : public MutableRunGraphResponseWrapper {
gtl::InlinedVector<std::pair<string, Tensor>, 4> recvs_;
StepStats step_stats_;
CostGraphDef cost_graph_;
+ std::vector<GraphDef> partition_graphs_;
};
// Proto-based message wrapper for use on the client side of the RunGraph RPC.
@@ -474,6 +482,9 @@ class OwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
void AddRecv(const string& key, const Tensor& value) override;
StepStats* mutable_step_stats() override;
CostGraphDef* mutable_cost_graph() override;
+ size_t num_partition_graphs() const override;
+ GraphDef* mutable_partition_graph(size_t i) override;
+ void AddPartitionGraph(const GraphDef& partition_graph) override;
protected:
RunGraphResponse* get_proto() override;
@@ -495,6 +506,9 @@ class NonOwnedProtoRunGraphResponse : public MutableRunGraphResponseWrapper {
void AddRecv(const string& key, const Tensor& value) override;
StepStats* mutable_step_stats() override;
CostGraphDef* mutable_cost_graph() override;
+ size_t num_partition_graphs() const override;
+ GraphDef* mutable_partition_graph(size_t i) override;
+ void AddPartitionGraph(const GraphDef& partition_graph) override;
protected:
RunGraphResponse* get_proto() override;
diff --git a/tensorflow/core/distributed_runtime/message_wrappers_test.cc b/tensorflow/core/distributed_runtime/message_wrappers_test.cc
index 5b0c2945b9..e87fff151e 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers_test.cc
+++ b/tensorflow/core/distributed_runtime/message_wrappers_test.cc
@@ -88,6 +88,7 @@ static void CheckRunGraphRequest(const RunGraphRequestWrapper& request) {
EXPECT_EQ(13, request.step_id());
EXPECT_FALSE(request.exec_opts().record_costs());
EXPECT_TRUE(request.exec_opts().record_timeline());
+ EXPECT_FALSE(request.exec_opts().record_partition_graphs());
EXPECT_EQ(2, request.num_sends());
Tensor val;
TF_EXPECT_OK(request.SendValue(0, &val));
@@ -105,6 +106,9 @@ static void BuildRunGraphResponse(
run_graph_response->mutable_step_stats()->add_dev_stats()->set_device(
"/cpu:0");
run_graph_response->mutable_cost_graph()->add_node()->set_name("cost_node");
+ GraphDef graph_def;
+ graph_def.set_version(1234);
+ run_graph_response->AddPartitionGraph(graph_def);
}
static void CheckRunGraphResponse(MutableRunGraphResponseWrapper* response) {
@@ -120,6 +124,9 @@ static void CheckRunGraphResponse(MutableRunGraphResponseWrapper* response) {
EXPECT_EQ("/cpu:0", response->mutable_step_stats()->dev_stats(0).device());
EXPECT_EQ(1, response->mutable_cost_graph()->node_size());
EXPECT_EQ("cost_node", response->mutable_cost_graph()->node(0).name());
+ EXPECT_EQ(1, response->num_partition_graphs());
+ GraphDef graph_def;
+ EXPECT_EQ(1234, response->mutable_partition_graph(0)->version());
}
static void BuildRunStepResponse(
@@ -131,6 +138,12 @@ static void BuildRunStepResponse(
"fetch_y:0", run_graph_response, 1));
*run_step_response->mutable_metadata()->mutable_step_stats() =
*run_graph_response->mutable_step_stats();
+ protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
+ run_step_response->mutable_metadata()->mutable_partition_graphs();
+ for (size_t i = 0; i < run_graph_response->num_partition_graphs(); i++) {
+ partition_graph_defs->Add()->Swap(
+ run_graph_response->mutable_partition_graph(i));
+ }
}
static void CheckRunStepResponse(
@@ -145,6 +158,8 @@ static void CheckRunStepResponse(
test::ExpectTensorEqual<int32>(TensorB(), val);
EXPECT_EQ(1, response.metadata().step_stats().dev_stats_size());
EXPECT_EQ("/cpu:0", response.metadata().step_stats().dev_stats(0).device());
+ EXPECT_EQ(1, response.metadata().partition_graphs_size());
+ EXPECT_EQ(1234, response.metadata().partition_graphs(0).version());
}
TEST(MessageWrappers, RunStepRequest_Basic) {
diff --git a/tensorflow/core/distributed_runtime/worker.cc b/tensorflow/core/distributed_runtime/worker.cc
index 16e450abb0..34b53e965f 100644
--- a/tensorflow/core/distributed_runtime/worker.cc
+++ b/tensorflow/core/distributed_runtime/worker.cc
@@ -156,10 +156,9 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
return;
}
}
- CostGraphDef* cost_graph = response->mutable_cost_graph();
session->graph_mgr->ExecuteAsync(
request->graph_handle(), step_id, session, request->exec_opts(),
- collector, cost_graph, cm, in,
+ collector, response, cm, in,
[this, step_id, response, session, cm, out, token, collector, opts,
done](Status s) {
if (s.ok()) {
@@ -230,7 +229,7 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
}
session->graph_mgr->ExecuteAsync(
graph_handle, step_id, session, request->exec_opts(),
- nullptr /* collector */, nullptr /* cost_graph */, cm, in,
+ nullptr /* collector */, nullptr /* response */, cm, in,
[this, token, step_id, cm](Status s) {
{
mutex_lock l(mu_);
diff --git a/tensorflow/core/protobuf/worker.proto b/tensorflow/core/protobuf/worker.proto
index 9d4a417aa3..137f9bc216 100644
--- a/tensorflow/core/protobuf/worker.proto
+++ b/tensorflow/core/protobuf/worker.proto
@@ -168,6 +168,7 @@ message CleanupAllResponse {
message ExecutorOpts {
bool record_costs = 1;
bool record_timeline = 3;
+ bool record_partition_graphs = 4;
};
message RunGraphRequest {
@@ -212,10 +213,12 @@ message RunGraphResponse {
// `RunGraphRequest.recv_key`.
repeated NamedTensorProto recv = 1;
- // If the request asked for execution stats or cost graph, these are returned
- // here.
+ // If the request asked for execution stats, the cost graph, or the partition
+ // graphs, these are returned here.
+ // TODO(suharshs): Package these in a RunMetadata instead.
StepStats step_stats = 2;
CostGraphDef cost_graph = 3;
+ repeated GraphDef partition_graph = 4;
}
////////////////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index 61d411b6f9..0cec75cf99 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -1508,6 +1508,22 @@ class SessionTest(test_util.TensorFlowTestCase):
else:
self.assertFalse(run_metadata.HasField('cost_graph'))
+ def runTestOutputPartitionGraphs(self, sess):
+ run_options = config_pb2.RunOptions(output_partition_graphs=True)
+ a = constant_op.constant(1)
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(a, options=run_options, run_metadata=run_metadata)
+ self.assertGreater(len(run_metadata.partition_graphs), 0)
+ sess.run(a, run_metadata=run_metadata)
+ self.assertEqual(len(run_metadata.partition_graphs), 0)
+
+ def testOutputPartitionGraphsDirect(self):
+ self.runTestOutputPartitionGraphs(session.Session())
+
+ def testOutputPartitionGraphsDistributed(self):
+ server = server_lib.Server.create_local_server()
+ self.runTestOutputPartitionGraphs(session.Session(server.target))
+
def testNonInteractiveSessionNesting(self):
sess1 = session.Session()
sess1_controller = sess1.as_default()