aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/grappler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-25 03:54:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 03:58:27 -0700
commita9d0bf9afc323be9ca52e1a23c52c3238a9b17cf (patch)
tree3f14792edc63fb081ab3c1903ec859f120469107 /tensorflow/core/grappler
parent1cb8437a46f1da7717ebc41ee29a74c305266ec6 (diff)
Swap Const ops back to GPU greedily.
PiperOrigin-RevId: 214415906
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc25
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc32
2 files changed, 56 insertions, 1 deletions
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 98c27300a9..2190d38937 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -71,6 +71,7 @@ bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) {
if (output_arg_id < 0) {
LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
<< node.DebugString() << "\n"
+ << fanin.node->DebugString() << "\n"
<< fanin_odef->DebugString();
return false;
}
@@ -158,7 +159,7 @@ string TryFindHostDevice(const gtl::FlatSet<string>& devices,
}
bool IsTPUGraphDef(const GraphDef& def) {
- for (auto node : def.node()) {
+ for (const auto& node : def.node()) {
if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
node.op() == "TPUPartitionedCall") {
return true;
@@ -197,6 +198,10 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Topologically sort the graph, so that we traverse the nodes in order. This
// will help us discover producer->consumer chains of Host ops.
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
+
+ // All the Const nodes, and their original devices in topological order.
+ std::vector<std::pair<NodeDef*, string>> const_nodes;
+
for (auto& node : *optimized_graph->mutable_node()) {
// Check if node already on CPU.
if (str_util::StrContains(node.device(), DEVICE_CPU)) {
@@ -230,10 +235,28 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
continue;
}
+ if (IsConstant(node)) {
+ const_nodes.emplace_back(&node, node.device());
+ }
// Try and swap the device to Host.
node.set_device(
internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
}
+
+ // Traverse all `const_nodes`, and greedily map them back to their original
+ // device (typically GPU).
+ for (auto& it : const_nodes) {
+ NodeDef* node = it.first;
+ const string& device = it.second;
+
+ // Check all the consumers of this node; if any of them is on the original
+ // device, swap this node back onto the original device.
+ for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
+ if (fanout.node->device() == device) {
+ node->set_device(device);
+ break;
+ }
+ }
+ }
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 339ddfd1b5..173cb3fe3c 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -128,6 +128,38 @@ TEST_F(PinToHostOptimizerTest, TopologicalSort) {
EXPECT_EQ(found, 4);
}
+// Checks that PinToHostOptimizer leaves every node of a small Const -> MatMul
+// graph on its original (empty) device, and that the optimized graph still
+// produces the same tensors as the input graph.
+TEST_F(PinToHostOptimizerTest, NoSwap) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  // `b` should be too big to swap, consequently `c` should not be swapped.
+  // PinToHostOptimizer should then detect that `a` should not be swapped.
+  Output a = ops::Const(s.WithOpName("a"), 1, {1, 1});
+  Output b = ops::Const(s.WithOpName("b"), 1, {1, 1024 * 1024});
+  Output c = ops::MatMul(s.WithOpName("c"), a, b);
+
+  GrapplerItem item;
+  item.fetch = {"a", "b", "c"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  // Reference values, computed from the graph before optimization.
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  GraphDef output;
+  PinToHostOptimizer optimizer(RewriterConfig::ON);
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  // Evaluate the *optimized* graph; re-evaluating `item.graph` here would only
+  // compare the input graph against itself and could never fail.
+  auto tensors = EvaluateNodes(output, item.fetch);
+  // Hard-fail on a size mismatch so the loop below never indexes out of range.
+  ASSERT_EQ(tensors_expected.size(), tensors.size());
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+  }
+
+  // No node should have been assigned a device, and no node may be added or
+  // dropped by the optimizer.
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    EXPECT_TRUE(node.device().empty());
+    ++found;
+  }
+  EXPECT_EQ(found, 3);
+}
+
TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});