diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-25 03:54:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 03:58:27 -0700 |
commit | a9d0bf9afc323be9ca52e1a23c52c3238a9b17cf (patch) | |
tree | 3f14792edc63fb081ab3c1903ec859f120469107 /tensorflow/core/grappler | |
parent | 1cb8437a46f1da7717ebc41ee29a74c305266ec6 (diff) |
Swap Const ops back to GPU greedily.
PiperOrigin-RevId: 214415906
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r-- | tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc | 25 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc | 32 |
2 files changed, 56 insertions, 1 deletion
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc index 98c27300a9..2190d38937 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc @@ -71,6 +71,7 @@ bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) { if (output_arg_id < 0) { LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n" << node.DebugString() << "\n" + << fanin.node->DebugString() << "\n" << fanin_odef->DebugString(); return false; } @@ -158,7 +159,7 @@ string TryFindHostDevice(const gtl::FlatSet<string>& devices, } bool IsTPUGraphDef(const GraphDef& def) { - for (auto node : def.node()) { + for (const auto& node : def.node()) { if (node.op() == "TPUCompile" || node.op() == "TPUExecute" || node.op() == "TPUPartitionedCall") { return true; @@ -197,6 +198,10 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Topologically sort the graph, so that we traverse the nodes in order. This // will help us discover producer->consumer chains of Host ops. TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph)); + + // All the Const nodes, and their original devices in topological order. + std::vector<std::pair<NodeDef*, string>> const_nodes; + for (auto& node : *optimized_graph->mutable_node()) { // Check if node already on CPU. if (str_util::StrContains(node.device(), DEVICE_CPU)) { @@ -230,10 +235,28 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, continue; } + if (IsConstant(node)) { + const_nodes.emplace_back(&node, node.device()); + } // Try and swap the device to Host. node.set_device( internal::TryFindHostDevice(devices, has_device_cpu, node.device())); } + + // Traverse all `const_nodes`, and map them back to GPU greedily. 
+ for (auto& it : const_nodes) { + NodeDef* node = it.first; + const string& device = it.second; + + // Check all the consumers of this node, if any of them are on the original + // device, swap this node back onto the original device. + for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) { + if (fanout.node->device() == device) { + node->set_device(device); + break; + } + } + } return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc index 339ddfd1b5..173cb3fe3c 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc @@ -128,6 +128,38 @@ TEST_F(PinToHostOptimizerTest, TopologicalSort) { EXPECT_EQ(found, 4); } +TEST_F(PinToHostOptimizerTest, NoSwap) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + // `b` should be too big to swap, consequently `c` should not be swapped. + // PinToHostOptimizer should then detect that `a` should not be swapped. 
+ Output a = ops::Const(s.WithOpName("a"), 1, {1, 1}); + Output b = ops::Const(s.WithOpName("b"), 1, {1, 1024 * 1024}); + Output c = ops::MatMul(s.WithOpName("c"), a, b); + + GrapplerItem item; + item.fetch = {"a", "b", "c"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + + GraphDef output; + PinToHostOptimizer optimizer(RewriterConfig::ON); + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + auto tensors = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(tensors_expected.size(), tensors.size()); + for (int i = 0; i < tensors.size(); ++i) { + test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]); + } + + int found = 0; + for (const NodeDef& node : output.node()) { + EXPECT_TRUE(node.device().empty()); + ++found; + } + EXPECT_EQ(found, 3); +} + TEST_F(PinToHostOptimizerTest, PortIdToArgId) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3}); |