aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yuan Yu <yuanbyu@google.com>2016-12-15 10:46:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-15 11:03:47 -0800
commit0a5d4db87c5cba5731ad265ad94e7a3a93c57c46 (patch)
tree412135b7e3c3e31abb4585a08a22df09a9ae28ab
parent0ee36e483560f86795bd52eccff34d85397bae3c (diff)
Automated rollback of change 141675118
Change: 142160746
-rw-r--r--tensorflow/core/graph/graph_partition.cc64
-rw-r--r--tensorflow/core/graph/graph_partition_test.cc70
2 files changed, 5 insertions, 129 deletions
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index 5f5e6a4b53..8e9eceb699 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -74,28 +74,6 @@ struct RecvInfo {
typedef std::unordered_map<DupRecvKey, RecvInfo, DupRecvKeyHash, DupRecvKeyEq>
DupRecvTable;
-struct DupControlKey {
- int dst_node_id; // Edge's dst node id
- GraphDef* src_graph; // Edge's src node is in this subgraph
-};
-
-struct DupControlKeyHash {
- size_t operator()(const DupControlKey& k) const {
- return Hash64(reinterpret_cast<const char*>(&k.src_graph),
- sizeof(k.src_graph), k.dst_node_id);
- }
-};
-
-struct DupControlKeyEq {
- bool operator()(const DupControlKey& x, const DupControlKey& y) const {
- return (x.dst_node_id == y.dst_node_id) && (x.src_graph == y.src_graph);
- }
-};
-
-typedef std::unordered_map<DupControlKey, NodeDef*, DupControlKeyHash,
- DupControlKeyEq>
- DupControlTable;
-
struct PairIntHash {
public:
std::size_t operator()(const std::pair<int, int>& x) const {
@@ -114,7 +92,6 @@ struct GraphInfo {
MemoryTypeMap input_types;
MemoryTypeMap output_types;
std::vector<ControlFlowInfo> cf_info;
- std::vector<int32> num_outgoing_control_edges;
};
DataType EdgeType(const Edge* e) {
@@ -540,12 +517,11 @@ Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
// Build memory and device type info for every node in the graph.
// TODO(yuanbyu): It might be simpler if we convert MemoryType to
// DeviceType for the inputs/outputs of each node.
-Status BuildGraphInfo(const Graph& g, GraphInfo* info) {
- info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
- info->num_outgoing_control_edges.resize(g.num_node_ids());
-
+Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
MemoryTypeVector input_memory_types;
MemoryTypeVector output_memory_types;
+
+ info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
for (const Node* node : g.nodes()) {
if (!node->IsOp()) continue; // Skip Sink/Source nodes.
@@ -568,14 +544,6 @@ Status BuildGraphInfo(const Graph& g, GraphInfo* info) {
for (size_t i = 0; i < output_memory_types.size(); ++i) {
info->output_types[{node_id, i}] = output_memory_types[i];
}
-
- int32 num_control_edges = 0;
- for (const Edge* edge : node->out_edges()) {
- if (edge->IsControlEdge()) {
- ++num_control_edges;
- }
- }
- info->num_outgoing_control_edges[node_id] = num_control_edges;
}
return Status::OK();
}
@@ -851,13 +819,12 @@ Status Partition(const PartitionOptions& opts, Graph* g,
// At this point, all the graph mutations have been done. Build memory
// and device type info for every node and edge in the graph.
- status = BuildGraphInfo(*g, &g_info);
+ status = BuildMemoryDeviceInfo(*g, &g_info);
if (!status.ok()) return status;
string dstp;
std::vector<const Edge*> inputs;
DupRecvTable dup_recv(3);
- DupControlTable dup_control(3);
// For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
// edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref
// edge to dst. We will add a control edge for every pair in
@@ -951,9 +918,7 @@ Status Partition(const PartitionOptions& opts, Graph* g,
}
// Check whether there is already a send/recv pair transferring
- // the same tensor/control from src to the dst partition. This
- // handles the dedup case when a single source in one partition
- // going to multiple destinations in another partition.
+ // the same tensor/control from the src to dst partition.
const bool on_host = IsDstInputOnHost(edge, g_info);
DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
auto iter = dup_recv.find(key);
@@ -978,22 +943,6 @@ Status Partition(const PartitionOptions& opts, Graph* g,
NodeDefBuilder::NodeOut send_from;
if (edge->IsControlEdge()) {
- DupControlKey key{dst->id(), src_graph};
- int32 num_control_edges = g_info.num_outgoing_control_edges[src->id()];
- if (num_control_edges == 1) {
- // Handle dedup of multiple control edges going from one partition
- // to a single destination in another partition.
- // Note: We require that src has only one outgoing control edge.
- // This is to avoid non-equivalent changes to the graph when
- // combinded with dedup of single-source-multi-destination.
- auto iter = dup_control.find(key);
- if (iter != dup_control.end()) {
- // Note: This may cause start_time(src) > start_time(iter->second).
- AddInput(iter->second, src->name(), Graph::kControlSlot);
- continue;
- }
- }
-
// Insert a dummy const node that will generate a tiny
// data element to be sent from send to recv.
VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
@@ -1007,9 +956,6 @@ Status Partition(const PartitionOptions& opts, Graph* g,
}
AddInput(dummy, src->name(), Graph::kControlSlot);
send_from.Reset(dummy->name(), 0, DT_FLOAT);
- if (num_control_edges == 1) {
- dup_control[key] = dummy;
- }
} else {
send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
}
diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc
index 3ab6b0ae36..d8322e6077 100644
--- a/tensorflow/core/graph/graph_partition_test.cc
+++ b/tensorflow/core/graph/graph_partition_test.cc
@@ -398,75 +398,5 @@ TEST_F(GraphPartitionTest, PartitionIncompleteGraph) {
EXPECT_EQ(error::INVALID_ARGUMENT, status.code()) << status;
}
-TEST_F(GraphPartitionTest, CrossDevice_MultiControl) {
- using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
- auto a1 = Input(in_.WithOpName("A1"));
- auto a2 = Input(in_.WithOpName("A2"));
- auto b1 = Input(in_.WithOpName("B1"));
- Combine(
- in_.WithOpName("B2").WithControlDependencies(a1).WithControlDependencies(
- a2),
- b1, b1);
-
- Partition(ToGraphDef(), &partitions_);
- EXPECT_EQ(2, partitions_.size());
-
- string a = "/job:a/replica:0/task:0/cpu:0";
- string b = "/job:a/replica:0/task:0/cpu:1";
- a1 = Input(scope_a_.WithOpName("A1"));
- a2 = Input(scope_a_.WithOpName("A2"));
- auto c = Const(scope_a_.WithOpName("A1/_0")
- .WithControlDependencies(a1)
- .WithControlDependencies(a2),
- {});
- _Send(scope_a_.WithOpName("A1/_1"), c, "edge_3_A1", a, 82, b);
- ExpectMatchA();
-
- auto recv =
- _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b);
- auto id = Identity(scope_b_.WithOpName("A1/_3"), recv);
- b1 = Input(scope_b_.WithOpName("B1"));
- Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1);
- ExpectMatchB();
-}
-
-TEST_F(GraphPartitionTest, CrossDevice_MultiControl_NoCombine) {
- using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
- auto a1 = Input(in_.WithOpName("A1"));
- auto a2 = Input(in_.WithOpName("A2"));
- auto b1 = Input(in_.WithOpName("B1"));
- Combine(
- in_.WithOpName("B2").WithControlDependencies(a1).WithControlDependencies(
- a2),
- b1, b1);
- Combine(in_.WithOpName("B3").WithControlDependencies(a1), b1, b1);
-
- Partition(ToGraphDef(), &partitions_);
- EXPECT_EQ(2, partitions_.size());
-
- string a = "/job:a/replica:0/task:0/cpu:0";
- string b = "/job:a/replica:0/task:0/cpu:1";
- a1 = Input(scope_a_.WithOpName("A1"));
- a2 = Input(scope_a_.WithOpName("A2"));
- auto c1 = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {});
- _Send(scope_a_.WithOpName("A1/_1"), c1, "edge_3_A1", a, 82, b);
- auto c2 = Const(scope_a_.WithOpName("A2/_4").WithControlDependencies(a2), {});
- _Send(scope_a_.WithOpName("A2/_5"), c2, "edge_7_A2", a, 82, b);
- ExpectMatchA();
-
- auto recv =
- _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b);
- auto id = Identity(scope_b_.WithOpName("A1/_3"), recv);
- auto recv1 =
- _Recv(scope_b_.WithOpName("A2/_6"), DT_FLOAT, "edge_7_A2", a, 82, b);
- auto id1 = Identity(scope_b_.WithOpName("A2/_7"), recv1);
- b1 = Input(scope_b_.WithOpName("B1"));
- Combine(scope_b_.WithOpName("B2")
- .WithControlDependencies(id)
- .WithControlDependencies(id1),
- b1, b1);
- Combine(scope_b_.WithOpName("B3").WithControlDependencies(id), b1, b1);
- ExpectMatchB();
-}
} // namespace
} // namespace tensorflow