author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-22 05:15:18 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-09-22 05:19:06 -0700
commit    ca552d54ac67be8837aeabdb43269846d9df4eb5 (patch)
tree      11d592685766ab64187b520d91d7dfa2b6f231fc /tensorflow/core/grappler
parent    e317152dad1aa66bc493abc046a60dbbf650de92 (diff)
Add PinToHostOptimizer to grappler: force small ops to run on the CPU (instead of
the GPU). This avoids many unnecessary CPU<->GPU memcpys and syncs.

PiperOrigin-RevId: 214108484
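For reference, the pass is gated behind the new `pin_to_host_optimization` toggle on RewriterConfig (the `set_pin_to_host_optimization` calls in the diff below rely on the corresponding proto field added in this change). A minimal sketch of opting in from a C++ client, with the surrounding session setup assumed:

    #include "tensorflow/core/protobuf/config.pb.h"
    #include "tensorflow/core/protobuf/rewriter_config.pb.h"

    tensorflow::ConfigProto config;
    tensorflow::RewriterConfig* rewriter =
        config.mutable_graph_options()->mutable_rewrite_options();
    // The pass is opt-in: the meta optimizer only runs it when the toggle
    // is explicitly ON (see meta_optimizer.cc below).
    rewriter->set_pin_to_host_optimization(tensorflow::RewriterConfig::ON);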
Diffstat (limited to 'tensorflow/core/grappler')
 tensorflow/core/grappler/clusters/cluster.cc                      |   1
 tensorflow/core/grappler/graph_view.cc                            |  30
 tensorflow/core/grappler/graph_view.h                             |  10
 tensorflow/core/grappler/graph_view_test.cc                       |  54
 tensorflow/core/grappler/optimizers/BUILD                         |  39
 tensorflow/core/grappler/optimizers/meta_optimizer.cc             |   6
 tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc      | 218
 tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h       |  62
 tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc | 162
 tensorflow/core/grappler/utils/grappler_test.cc                   |   9
 10 files changed, 588 insertions(+), 3 deletions(-)
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc
index 7171ae059b..3b1d7d8347 100644
--- a/tensorflow/core/grappler/clusters/cluster.cc
+++ b/tensorflow/core/grappler/clusters/cluster.cc
@@ -83,6 +83,7 @@ void Cluster::DisableOptimizer(bool disable) {
rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT);
rewriter_config->set_shape_optimization(RewriterConfig::OFF);
rewriter_config->set_remapping(RewriterConfig::OFF);
+ rewriter_config->set_pin_to_host_optimization(RewriterConfig::OFF);
rewriter_config->mutable_auto_parallel()->set_enable(false);
rewriter_config->clear_optimizers();
} else {
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index a6b6b6f8b2..b8d8243174 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -14,11 +14,41 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+ for (int output_arg_id = 0; output_arg_id < op.output_arg_size();
+ ++output_arg_id) {
+ if (port_id < 0) {
+ return -1;
+ } else if (port_id == 0) {
+ return output_arg_id;
+ }
+
+ const auto& output_arg = op.output_arg(output_arg_id);
+ if (!output_arg.number_attr().empty()) {
+ const int n = node.attr().at(output_arg.number_attr()).i();
+ if (n < 0) {
+ // This should never happen.
+ DCHECK_GE(n, 0);
+ return -1;
+ }
+ if (port_id < n) {
+ return output_arg_id;
+ }
+ port_id -= n;
+ } else {
+ --port_id;
+ }
+ }
+
+ return -1;
+}
+
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
for (int i = 0; i < graph_->node_size(); i++) {
auto node = graph_->mutable_node(i);
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index ac260f85a0..ec946ca3b5 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -20,11 +20,21 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
+// Map a node/op's output port_id to arg_id.
+//
+// The port_id refers to the n-th tensor of the node, while the arg_id refers to
+// the n-th arg of the op. These two can be different if an op's arg is a list
+// of tensors.
+//
+// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
+
// A utility class to simplify the traversal of a GraphDef.
class GraphView {
public:
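To make the port/arg distinction concrete: ShapeN declares a single output arg that is a list of N tensors, so several port_ids collapse onto one arg_id. A hypothetical call sequence (`shape_n_node` and `shape_n_op_def` stand for a ShapeN NodeDef and its looked-up OpDef, as in the tests below), assuming N = 3:

    OpOutputPortIdToArgId(shape_n_node, shape_n_op_def, 0);  // -> 0
    OpOutputPortIdToArgId(shape_n_node, shape_n_op_def, 2);  // -> 0 (same list arg)
    OpOutputPortIdToArgId(shape_n_node, shape_n_op_def, 3);  // -> -1 (out of range)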
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index 958eb921fb..30512d9d47 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -25,6 +25,60 @@ namespace {
class GraphViewTest : public ::testing::Test {};
+TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
+ ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& a_node_def = *graph_view.GetNode("a");
+ const NodeDef& b_node_def = *graph_view.GetNode("b");
+
+ const OpDef* a_op_def = nullptr;
+ const OpDef* b_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(a_node_def.op(), &a_op_def).ok());
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1));
+
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 3));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4));
+}
+
+TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
+ for (int num_splits : {1, 2}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10});
+ ops::SparseSplit b(s.WithOpName("b"), a, a, a, a, num_splits);
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& b_node_def = *graph_view.GetNode("b");
+ const OpDef* b_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+ for (int port_id = 0; port_id <= num_splits * 3; ++port_id) {
+ int arg_id = -1;
+ if (port_id < num_splits * 3) {
+ arg_id = port_id / num_splits;
+ }
+ EXPECT_EQ(arg_id, OpOutputPortIdToArgId(b_node_def, *b_op_def, port_id));
+ }
+ }
+}
+
TEST_F(GraphViewTest, BasicGraph) {
TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"});
GrapplerItem item;
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 261dee4382..960d1addb3 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -518,6 +518,7 @@ cc_library(
":loop_optimizer",
":memory_optimizer",
":model_pruner",
+ ":pin_to_host_optimizer",
":remapper",
":scoped_allocator_optimizer",
":shape_optimizer",
@@ -883,3 +884,41 @@ tf_cc_test(
"//tensorflow/core/grappler/utils:grappler_test",
],
)
+
+cc_library(
+ name = "pin_to_host_optimizer",
+ srcs = ["pin_to_host_optimizer.cc"],
+ hdrs = [
+ "pin_to_host_optimizer.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_optimizer",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:frame",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
+ "//tensorflow/core/grappler/utils:topological_sort",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "pin_to_host_optimizer_test",
+ srcs = ["pin_to_host_optimizer_test.cc"],
+ deps = [
+ ":pin_to_host_optimizer",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/utils:grappler_test",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 4b0cbfaa82..3da7a72e80 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
@@ -105,6 +106,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("scoped_allocator",
new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
cfg_.scoped_allocator_opts()));
+ MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
return std::unique_ptr<GraphOptimizer>();
}
@@ -133,6 +135,9 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
+ if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
+ optimizers->push_back(MakeUnique<PinToHostOptimizer>());
+ }
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
optimizers->push_back(
MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
@@ -468,6 +473,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
cfg.debug_stripper() == RewriterConfig::ON ||
cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
+ cfg.pin_to_host_optimization() == RewriterConfig::ON ||
!cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
}
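Note that MakeNewOptimizer registers the pass under the name "small_op", so besides the dedicated toggle it can also be requested through RewriterConfig's generic `optimizers` list. A sketch, reusing the `rewriter` object from the earlier snippet:

    // "small_op" is the registration name passed to MK_OPT above.
    rewriter->add_optimizers("small_op");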
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
new file mode 100644
index 0000000000..8a65cd3ec3
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -0,0 +1,218 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+
+// TODO(williamchan): Change this constant to be something smarter, maybe
+// dynamically determined.
+constexpr int64 kTensorMaxSize = 64;
+
+// Find KernelDef for `node`.
+Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
+ // Try to find a KernelDef for node.device; fall back to GPU, then CPU.
+ for (const DeviceType& device :
+ {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) {
+ Status s = FindKernelDef(device, node, kdef, nullptr);
+ if (s.ok()) {
+ return Status::OK();
+ }
+ }
+
+ return errors::NotFound("Could not find KernelDef for op: ", node.op());
+}
+
+// Check whether all of the node's inputs are pinned to CPU memory.
+bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) {
+ // Loop through all the inputs, excluding control inputs.
+ for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) {
+ // Check if (the fanin) op's device is on CPU.
+ if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) {
+ continue;
+ }
+
+ // Check if (the fanin) op's output port is pinned to HostMemory.
+ const OpDef* fanin_odef = nullptr;
+ Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find OpDef for : " << fanin.node->op();
+ return false;
+ }
+
+ const int output_arg_id =
+ OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id);
+ if (output_arg_id < 0) {
+ LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
+ << node.DebugString() << "\n"
+ << fanin_odef->DebugString();
+ return false;
+ }
+
+ const KernelDef* fanin_kdef = nullptr;
+ s = TryFindKernelDef(*fanin.node, &fanin_kdef);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op();
+ return false;
+ }
+
+ bool fanin_pinned = false;
+ for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) {
+ if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) {
+ fanin_pinned = true;
+ break;
+ }
+ }
+
+ if (!fanin_pinned) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
+ // Check whether the tensor is a small integer tensor.
+
+ // The type must be int32 or int64.
+ if (prop.dtype() != DataType::DT_INT32 &&
+ prop.dtype() != DataType::DT_INT64) {
+ return false;
+ }
+
+ // Check size known and small.
+ const int64 size = NumCoefficients(prop.shape());
+ if (size < 0 || size > kTensorMaxSize) {
+ return false;
+ }
+
+ return true;
+}
+
+bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties,
+ const NodeDef& node) {
+ for (const auto& prop : properties.GetInputProperties(node.name())) {
+ if (!IsTensorIntegerAndSmall(prop)) {
+ return false;
+ }
+ }
+
+ for (const auto& prop : properties.GetOutputProperties(node.name())) {
+ if (!IsTensorIntegerAndSmall(prop)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, const string& device) {
+ // Force this node onto the CPU.
+ if (device.empty() && has_device_cpu) {
+ return "/device:CPU:0";
+ } else if (str_util::StrContains(device, DEVICE_GPU)) {
+ // Sometimes the cluster can have:
+ // devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
+ // and we need to handle them properly.
+ for (const auto& device_match :
+ {std::pair<string, string>("GPU", "CPU:0"),
+ std::pair<string, string>("/device", "/device:CPU:0")}) {
+ const string device_host =
+ strings::StrCat(device.substr(0, device.rfind(device_match.first)),
+ device_match.second);
+ if (devices.find(device_host) != devices.end()) {
+ return device_host;
+ }
+ }
+ }
+
+ // We couldn't find an appropriate Host device; return the original device.
+ return device;
+}
+} // end namespace internal
+
+Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+
+ GraphProperties properties(item);
+ bool has_properties = false;
+ GraphView graph(optimized_graph);
+
+ gtl::FlatSet<string> devices;
+ if (cluster) {
+ const std::vector<string> device_names = cluster->GetDeviceNames();
+ devices.insert(device_names.begin(), device_names.end());
+ } else {
+ devices = {"/device:CPU:0"};
+ }
+
+ const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();
+
+ // Topologically sort the graph, so that we traverse the nodes in order. This
+ // will help us discover producer->consumer chains of Host ops.
+ TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
+ for (auto& node : *optimized_graph->mutable_node()) {
+ // Check if node already on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ continue;
+ }
+
+ // Check that the node can run on the CPU.
+ Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr);
+ if (!s.ok()) {
+ continue;
+ }
+
+ // Check that all inputs are pinned to the CPU.
+ if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) {
+ continue;
+ }
+
+ if (!has_properties) {
+ // This is an expensive call, so make it lazily.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ has_properties = true;
+ }
+
+ // Check all inputs and outputs are integers and small.
+ if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) {
+ continue;
+ }
+
+ // Try to swap the device to the Host.
+ node.set_device(
+ internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
+ }
+ return Status::OK();
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
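The device matching in TryFindHostDevice is purely string based: it first tries replacing the suffix starting at the last "GPU" with "CPU:0", then falls back to replacing everything after the last "/device" with "/device:CPU:0", and returns the original device if neither candidate exists in the cluster. Illustrative input/output pairs (mirroring the unit tests added below):

    // devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
    TryFindHostDevice(devices, true, "/device:XLA_GPU:0");   // -> "/device:CPU:0"
    // devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"}
    TryFindHostDevice(devices, false, "/device:XLA_GPU:0");  // -> "/device:XLA_CPU:0"
    // devices = {"/device:XLA_GPU:0"}  (no host device available)
    TryFindHostDevice(devices, false, "/device:XLA_GPU:0");  // -> "/device:XLA_GPU:0"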
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
new file mode 100644
index 0000000000..d557a03463
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+
+#include <unordered_set>
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+// Try to find an appropriate Host device in `devices`, given `device`.
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, const string& device);
+} // end namespace internal
+
+// Optimize TensorFlow ops that should be swapped onto the CPU to avoid
+// excessive CPU<->GPU memcpy/sync.
+//
+// TODO(williamchan): The current heuristic will swap any small integer Const to
+// CPU. This may produce a cpu->cpu->gpu placement where the original
+// gpu->gpu->gpu placement may have been better/faster. We should probably fix this.
+class PinToHostOptimizer : public GraphOptimizer {
+ public:
+ PinToHostOptimizer() : opt_level_(RewriterConfig::DEFAULT) {}
+ explicit PinToHostOptimizer(RewriterConfig::Toggle opt_level)
+ : opt_level_(opt_level) {}
+
+ ~PinToHostOptimizer() override {}
+
+ string name() const override { return "pin_to_host_optimizer"; }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ RewriterConfig::Toggle opt_level_;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
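Standalone usage follows the usual GraphOptimizer pattern; a minimal sketch (population of `item.graph` is elided; passing a null Cluster makes the optimizer assume a lone "/device:CPU:0", as in the tests below):

    #include "tensorflow/core/grappler/grappler_item.h"
    #include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"

    tensorflow::grappler::GrapplerItem item;
    // item.graph = ...;  // the GraphDef to optimize
    tensorflow::grappler::PinToHostOptimizer optimizer(
        tensorflow::RewriterConfig::ON);
    tensorflow::GraphDef output;
    TF_CHECK_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));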
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
new file mode 100644
index 0000000000..339ddfd1b5
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -0,0 +1,162 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class PinToHostOptimizerTest : public GrapplerTest {};
+
+TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+ gtl::FlatSet<string> devices = {};
+ EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
+
+ devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
+ "/device:CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
+ "/device:CPU:0");
+
+ devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+ "/device:XLA_CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+ "/device:XLA_CPU:0");
+
+ devices = {"/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+ "/device:XLA_GPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+ "/device:XLA_GPU:*");
+}
+
+TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+ Output c = ops::Shape(s.WithOpName("c"), a);
+ Output d = ops::Const(s.WithOpName("d"), 0, {1});
+ Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+ GrapplerItem item;
+ item.fetch = {"a", "c", "d", "e"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_TRUE(node.device().empty());
+ } else if (node.name() == "d" || node.name() == "e") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, TopologicalSort) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+ Output c = ops::Shape(s.WithOpName("c"), a);
+ Output d = ops::Const(s.WithOpName("d"), 0, {1});
+ Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+ GrapplerItem item;
+ item.fetch = {"a", "c", "d", "e"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ // Reverse the graph, and hence rely on the optimizer to sort it.
+ std::reverse(item.graph.mutable_node()->begin(),
+ item.graph.mutable_node()->end());
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_TRUE(node.device().empty());
+ } else if (node.name() == "d" || node.name() == "e") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
+ ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+ GrapplerItem item;
+ item.fetch = {"a", "b"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ ++found;
+ }
+ EXPECT_EQ(found, 2);
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc
index 910b0acaef..6266733f3e 100644
--- a/tensorflow/core/grappler/utils/grappler_test.cc
+++ b/tensorflow/core/grappler/utils/grappler_test.cc
@@ -30,13 +30,16 @@ GrapplerTest::GrapplerTest() {
// optimizations interfering in the comparison.
RewriterConfig* cfg =
options_.config.mutable_graph_options()->mutable_rewrite_options();
- cfg->set_constant_folding(RewriterConfig::OFF);
+ // TODO(rmlarsen): Add utility to generate config w/ all optimizers turned
+ // off.
cfg->set_arithmetic_optimization(RewriterConfig::OFF);
+ cfg->set_constant_folding(RewriterConfig::OFF);
+ cfg->set_debug_stripper(RewriterConfig::OFF);
cfg->set_dependency_optimization(RewriterConfig::OFF);
- cfg->set_loop_optimization(RewriterConfig::OFF);
cfg->set_function_optimization(RewriterConfig::OFF);
cfg->set_layout_optimizer(RewriterConfig::OFF);
- cfg->set_debug_stripper(RewriterConfig::OFF);
+ cfg->set_loop_optimization(RewriterConfig::OFF);
+ cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
}
std::vector<Tensor> GrapplerTest::EvaluateNodes(