diff options
-rw-r--r-- | tensorflow/core/graph/graph_partition.cc | 64 | ||||
-rw-r--r-- | tensorflow/core/graph/graph_partition_test.cc | 70 |
2 files changed, 5 insertions, 129 deletions
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index 5f5e6a4b53..8e9eceb699 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -74,28 +74,6 @@ struct RecvInfo { typedef std::unordered_map<DupRecvKey, RecvInfo, DupRecvKeyHash, DupRecvKeyEq> DupRecvTable; -struct DupControlKey { - int dst_node_id; // Edge's dst node id - GraphDef* src_graph; // Edge's src node is in this subgraph -}; - -struct DupControlKeyHash { - size_t operator()(const DupControlKey& k) const { - return Hash64(reinterpret_cast<const char*>(&k.src_graph), - sizeof(k.src_graph), k.dst_node_id); - } -}; - -struct DupControlKeyEq { - bool operator()(const DupControlKey& x, const DupControlKey& y) const { - return (x.dst_node_id == y.dst_node_id) && (x.src_graph == y.src_graph); - } -}; - -typedef std::unordered_map<DupControlKey, NodeDef*, DupControlKeyHash, - DupControlKeyEq> - DupControlTable; - struct PairIntHash { public: std::size_t operator()(const std::pair<int, int>& x) const { @@ -114,7 +92,6 @@ struct GraphInfo { MemoryTypeMap input_types; MemoryTypeMap output_types; std::vector<ControlFlowInfo> cf_info; - std::vector<int32> num_outgoing_control_edges; }; DataType EdgeType(const Edge* e) { @@ -540,12 +517,11 @@ Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src, // Build memory and device type info for every node in the graph. // TODO(yuanbyu): It might be simpler if we convert MemoryType to // DeviceType for the inputs/outputs of each node. -Status BuildGraphInfo(const Graph& g, GraphInfo* info) { - info->device_types.resize(g.num_node_ids(), DEVICE_CPU); - info->num_outgoing_control_edges.resize(g.num_node_ids()); - +Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) { MemoryTypeVector input_memory_types; MemoryTypeVector output_memory_types; + + info->device_types.resize(g.num_node_ids(), DEVICE_CPU); for (const Node* node : g.nodes()) { if (!node->IsOp()) continue; // Skip Sink/Source nodes. @@ -568,14 +544,6 @@ Status BuildGraphInfo(const Graph& g, GraphInfo* info) { for (size_t i = 0; i < output_memory_types.size(); ++i) { info->output_types[{node_id, i}] = output_memory_types[i]; } - - int32 num_control_edges = 0; - for (const Edge* edge : node->out_edges()) { - if (edge->IsControlEdge()) { - ++num_control_edges; - } - } - info->num_outgoing_control_edges[node_id] = num_control_edges; } return Status::OK(); } @@ -851,13 +819,12 @@ Status Partition(const PartitionOptions& opts, Graph* g, // At this point, all the graph mutations have been done. Build memory // and device type info for every node and edge in the graph. - status = BuildGraphInfo(*g, &g_info); + status = BuildMemoryDeviceInfo(*g, &g_info); if (!status.ok()) return status; string dstp; std::vector<const Edge*> inputs; DupRecvTable dup_recv(3); - DupControlTable dup_control(3); // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref // edge to dst. We will add a control edge for every pair in @@ -951,9 +918,7 @@ Status Partition(const PartitionOptions& opts, Graph* g, } // Check whether there is already a send/recv pair transferring - // the same tensor/control from src to the dst partition. This - // handles the dedup case when a single source in one partition - // going to multiple destinations in another partition. + // the same tensor/control from the src to dst partition. const bool on_host = IsDstInputOnHost(edge, g_info); DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host}; auto iter = dup_recv.find(key); @@ -978,22 +943,6 @@ Status Partition(const PartitionOptions& opts, Graph* g, NodeDefBuilder::NodeOut send_from; if (edge->IsControlEdge()) { - DupControlKey key{dst->id(), src_graph}; - int32 num_control_edges = g_info.num_outgoing_control_edges[src->id()]; - if (num_control_edges == 1) { - // Handle dedup of multiple control edges going from one partition - // to a single destination in another partition. - // Note: We require that src has only one outgoing control edge. - // This is to avoid non-equivalent changes to the graph when - // combinded with dedup of single-source-multi-destination. - auto iter = dup_control.find(key); - if (iter != dup_control.end()) { - // Note: This may cause start_time(src) > start_time(iter->second). - AddInput(iter->second, src->name(), Graph::kControlSlot); - continue; - } - } - // Insert a dummy const node that will generate a tiny // data element to be sent from send to recv. VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "[" @@ -1007,9 +956,6 @@ Status Partition(const PartitionOptions& opts, Graph* g, } AddInput(dummy, src->name(), Graph::kControlSlot); send_from.Reset(dummy->name(), 0, DT_FLOAT); - if (num_control_edges == 1) { - dup_control[key] = dummy; - } } else { send_from.Reset(src->name(), edge->src_output(), EdgeType(edge)); } diff --git a/tensorflow/core/graph/graph_partition_test.cc b/tensorflow/core/graph/graph_partition_test.cc index 3ab6b0ae36..d8322e6077 100644 --- a/tensorflow/core/graph/graph_partition_test.cc +++ b/tensorflow/core/graph/graph_partition_test.cc @@ -398,75 +398,5 @@ TEST_F(GraphPartitionTest, PartitionIncompleteGraph) { EXPECT_EQ(error::INVALID_ARGUMENT, status.code()) << status; } -TEST_F(GraphPartitionTest, CrossDevice_MultiControl) { - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) - auto a1 = Input(in_.WithOpName("A1")); - auto a2 = Input(in_.WithOpName("A2")); - auto b1 = Input(in_.WithOpName("B1")); - Combine( - in_.WithOpName("B2").WithControlDependencies(a1).WithControlDependencies( - a2), - b1, b1); - - Partition(ToGraphDef(), &partitions_); - EXPECT_EQ(2, partitions_.size()); - - string a = "/job:a/replica:0/task:0/cpu:0"; - string b = "/job:a/replica:0/task:0/cpu:1"; - a1 = Input(scope_a_.WithOpName("A1")); - a2 = Input(scope_a_.WithOpName("A2")); - auto c = Const(scope_a_.WithOpName("A1/_0") - .WithControlDependencies(a1) - .WithControlDependencies(a2), - {}); - _Send(scope_a_.WithOpName("A1/_1"), c, "edge_3_A1", a, 82, b); - ExpectMatchA(); - - auto recv = - _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b); - auto id = Identity(scope_b_.WithOpName("A1/_3"), recv); - b1 = Input(scope_b_.WithOpName("B1")); - Combine(scope_b_.WithOpName("B2").WithControlDependencies(id), b1, b1); - ExpectMatchB(); -} - -TEST_F(GraphPartitionTest, CrossDevice_MultiControl_NoCombine) { - using namespace ::tensorflow::ops; // NOLINT(build/namespaces) - auto a1 = Input(in_.WithOpName("A1")); - auto a2 = Input(in_.WithOpName("A2")); - auto b1 = Input(in_.WithOpName("B1")); - Combine( - in_.WithOpName("B2").WithControlDependencies(a1).WithControlDependencies( - a2), - b1, b1); - Combine(in_.WithOpName("B3").WithControlDependencies(a1), b1, b1); - - Partition(ToGraphDef(), &partitions_); - EXPECT_EQ(2, partitions_.size()); - - string a = "/job:a/replica:0/task:0/cpu:0"; - string b = "/job:a/replica:0/task:0/cpu:1"; - a1 = Input(scope_a_.WithOpName("A1")); - a2 = Input(scope_a_.WithOpName("A2")); - auto c1 = Const(scope_a_.WithOpName("A1/_0").WithControlDependencies(a1), {}); - _Send(scope_a_.WithOpName("A1/_1"), c1, "edge_3_A1", a, 82, b); - auto c2 = Const(scope_a_.WithOpName("A2/_4").WithControlDependencies(a2), {}); - _Send(scope_a_.WithOpName("A2/_5"), c2, "edge_7_A2", a, 82, b); - ExpectMatchA(); - - auto recv = - _Recv(scope_b_.WithOpName("A1/_2"), DT_FLOAT, "edge_3_A1", a, 82, b); - auto id = Identity(scope_b_.WithOpName("A1/_3"), recv); - auto recv1 = - _Recv(scope_b_.WithOpName("A2/_6"), DT_FLOAT, "edge_7_A2", a, 82, b); - auto id1 = Identity(scope_b_.WithOpName("A2/_7"), recv1); - b1 = Input(scope_b_.WithOpName("B1")); - Combine(scope_b_.WithOpName("B2") - .WithControlDependencies(id) - .WithControlDependencies(id1), - b1, b1); - Combine(scope_b_.WithOpName("B3").WithControlDependencies(id), b1, b1); - ExpectMatchB(); -} } // namespace } // namespace tensorflow |