author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-09-22 05:15:18 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-09-22 05:19:06 -0700
commit     ca552d54ac67be8837aeabdb43269846d9df4eb5 (patch)
tree       11d592685766ab64187b520d91d7dfa2b6f231fc /tensorflow/core/grappler
parent     e317152dad1aa66bc493abc046a60dbbf650de92 (diff)
Add PinToHostOptimizer to grappler: force small ops to run on the CPU (instead
of the GPU). This avoids many unnecessary CPU<->GPU memcpys and syncs.
PiperOrigin-RevId: 214108484
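
As context for the change below, a minimal sketch of how a client opts in via
the RewriterConfig field this commit introduces; the helper name is
hypothetical and not part of the change. The meta optimizer only schedules the
pass when the toggle is explicitly ON (see the meta_optimizer.cc hunk below);
the DEFAULT value leaves it disabled.

#include "tensorflow/core/protobuf/rewriter_config.pb.h"

// Hypothetical helper: build a RewriterConfig that enables PinToHostOptimizer.
tensorflow::RewriterConfig MakePinToHostConfig() {
  tensorflow::RewriterConfig cfg;
  cfg.set_pin_to_host_optimization(tensorflow::RewriterConfig::ON);
  return cfg;
}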
Diffstat (limited to 'tensorflow/core/grappler')
-rw-r--r--  tensorflow/core/grappler/clusters/cluster.cc                        1
-rw-r--r--  tensorflow/core/grappler/graph_view.cc                             30
-rw-r--r--  tensorflow/core/grappler/graph_view.h                              10
-rw-r--r--  tensorflow/core/grappler/graph_view_test.cc                        54
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD                          39
-rw-r--r--  tensorflow/core/grappler/optimizers/meta_optimizer.cc               6
-rw-r--r--  tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc      218
-rw-r--r--  tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h        62
-rw-r--r--  tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc 162
-rw-r--r--  tensorflow/core/grappler/utils/grappler_test.cc                     9
10 files changed, 588 insertions, 3 deletions
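
Before the patch body, a minimal sketch of driving the new pass directly,
mirroring the unit tests added below; "RunPinToHost" is a hypothetical name.
Passing a null Cluster makes the optimizer assume a lone "/device:CPU:0" (see
PinToHostOptimizer::Optimize() in the new .cc file).

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"

// Run PinToHostOptimizer standalone on a GrapplerItem, as the tests do.
tensorflow::Status RunPinToHost(const tensorflow::grappler::GrapplerItem& item,
                                tensorflow::GraphDef* output) {
  tensorflow::grappler::PinToHostOptimizer optimizer(
      tensorflow::RewriterConfig::ON);
  return optimizer.Optimize(/*cluster=*/nullptr, item, output);
}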
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc
index 7171ae059b..3b1d7d8347 100644
--- a/tensorflow/core/grappler/clusters/cluster.cc
+++ b/tensorflow/core/grappler/clusters/cluster.cc
@@ -83,6 +83,7 @@ void Cluster::DisableOptimizer(bool disable) {
     rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT);
     rewriter_config->set_shape_optimization(RewriterConfig::OFF);
     rewriter_config->set_remapping(RewriterConfig::OFF);
+    rewriter_config->set_pin_to_host_optimization(RewriterConfig::OFF);
     rewriter_config->mutable_auto_parallel()->set_enable(false);
     rewriter_config->clear_optimizers();
   } else {
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index a6b6b6f8b2..b8d8243174 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -14,11 +14,41 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/grappler/utils.h"
 
 namespace tensorflow {
 namespace grappler {
 
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+  for (int output_arg_id = 0; output_arg_id < op.output_arg_size();
+       ++output_arg_id) {
+    if (port_id < 0) {
+      return -1;
+    } else if (port_id == 0) {
+      return output_arg_id;
+    }
+
+    const auto& output_arg = op.output_arg(output_arg_id);
+    if (!output_arg.number_attr().empty()) {
+      const int n = node.attr().at(output_arg.number_attr()).i();
+      if (n < 0) {
+        // This should never happen.
+        DCHECK_GE(n, 0);
+        return -1;
+      }
+      if (port_id < n) {
+        return output_arg_id;
+      }
+      port_id -= n;
+    } else {
+      --port_id;
+    }
+  }
+
+  return -1;
+}
+
 GraphView::GraphView(GraphDef* graph) : graph_(graph) {
   for (int i = 0; i < graph_->node_size(); i++) {
     auto node = graph_->mutable_node(i);
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index ac260f85a0..ec946ca3b5 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -20,11 +20,21 @@ limitations under the License.
 #include <unordered_set>
 
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
 #include "tensorflow/core/platform/types.h"
 
 namespace tensorflow {
 namespace grappler {
 
+// Map a node/op's output port_id to arg_id.
+//
+// The port_id refers to the n-th tensor of the node, while the arg_id refers to
+// the n-th arg of the op. These two can be different if an op's arg is a list
+// of tensors.
+//
+// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
+
 // A utility class to simplify the traversal of a GraphDef.
 class GraphView {
  public:
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index 958eb921fb..30512d9d47 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -25,6 +25,60 @@ namespace {
 
 class GraphViewTest : public ::testing::Test {};
 
+TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
+  ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+  GraphDef graph_def;
+  TF_CHECK_OK(s.ToGraphDef(&graph_def));
+  GraphView graph_view(&graph_def);
+
+  const NodeDef& a_node_def = *graph_view.GetNode("a");
+  const NodeDef& b_node_def = *graph_view.GetNode("b");
+
+  const OpDef* a_op_def = nullptr;
+  const OpDef* b_op_def = nullptr;
+  EXPECT_TRUE(
+      OpRegistry::Global()->LookUpOpDef(a_node_def.op(), &a_op_def).ok());
+  EXPECT_TRUE(
+      OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+  EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0));
+  EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1));
+
+  EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0));
+  EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1));
+  EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2));
+  EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 3));
+  EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4));
+}
+
+TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
+  for (int num_splits : {1, 2}) {
+    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+    Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10});
+    ops::SparseSplit b(s.WithOpName("b"), a, a, a, a, num_splits);
+
+    GraphDef graph_def;
+    TF_CHECK_OK(s.ToGraphDef(&graph_def));
+    GraphView graph_view(&graph_def);
+
+    const NodeDef& b_node_def = *graph_view.GetNode("b");
+    const OpDef* b_op_def = nullptr;
+    EXPECT_TRUE(
+        OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+    for (int port_id = 0; port_id <= num_splits * 3; ++port_id) {
+      int arg_id = -1;
+      if (port_id < num_splits * 3) {
+        arg_id = port_id / num_splits;
+      }
+      EXPECT_EQ(arg_id, OpOutputPortIdToArgId(b_node_def, *b_op_def, port_id));
+    }
+  }
+}
+
 TEST_F(GraphViewTest, BasicGraph) {
   TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"});
   GrapplerItem item;
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 261dee4382..960d1addb3 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -518,6 +518,7 @@ cc_library(
         ":loop_optimizer",
         ":memory_optimizer",
         ":model_pruner",
+        ":pin_to_host_optimizer",
         ":remapper",
         ":scoped_allocator_optimizer",
         ":shape_optimizer",
@@ -883,3 +884,41 @@ tf_cc_test(
         "//tensorflow/core/grappler/utils:grappler_test",
     ],
 )
+
+cc_library(
+    name = "pin_to_host_optimizer",
+    srcs = ["pin_to_host_optimizer.cc"],
+    hdrs = [
+        "pin_to_host_optimizer.h",
+    ],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":graph_optimizer",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/grappler:graph_view",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler:op_types",
+        "//tensorflow/core/grappler:utils",
+        "//tensorflow/core/grappler/costs:graph_properties",
+        "//tensorflow/core/grappler/utils:frame",
+        "//tensorflow/core/grappler/utils:symbolic_shapes",
+        "//tensorflow/core/grappler/utils:topological_sort",
+    ],
+)
+
+tf_cuda_cc_test(
+    name = "pin_to_host_optimizer_test",
+    srcs = ["pin_to_host_optimizer_test.cc"],
+    deps = [
+        ":pin_to_host_optimizer",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+        "//tensorflow/core/grappler:grappler_item",
+        "//tensorflow/core/grappler/utils:grappler_test",
+    ],
+)
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index 4b0cbfaa82..3da7a72e80 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -29,6 +29,7 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/remapper.h"
 #include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
@@ -105,6 +106,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
   MK_OPT("scoped_allocator",
          new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
                                       cfg_.scoped_allocator_opts()));
+  MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
 
   return std::unique_ptr<GraphOptimizer>();
 }
@@ -133,6 +135,9 @@ Status MetaOptimizer::InitializeOptimizers(
   if (cfg_.remapping() != RewriterConfig::OFF) {
     optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
   }
+  if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
+    optimizers->push_back(MakeUnique<PinToHostOptimizer>());
+  }
   if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
     optimizers->push_back(
         MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
@@ -468,6 +473,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
          cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
          cfg.debug_stripper() == RewriterConfig::ON ||
          cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
+         cfg.pin_to_host_optimization() == RewriterConfig::ON ||
          !cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
 }
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
new file mode 100644
index 0000000000..8a65cd3ec3
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -0,0 +1,218 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+
+// TODO(williamchan): Change this constant to be something smarter, maybe
+// dynamically determined.
+constexpr int64 kTensorMaxSize = 64;
+
+// Find KernelDef for `node`.
+Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
+  // Try find KernelDef for node.device, else GPU or CPU.
+  for (const DeviceType& device :
+       {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) {
+    Status s = FindKernelDef(device, node, kdef, nullptr);
+    if (s.ok()) {
+      return Status::OK();
+    }
+  }
+
+  return errors::NotFound("Could not find KernelDef for op: ", node.op());
+}
+
+// Check if all node's inputs are pinned to CPU memory.
+bool AreAllNodeInputsPinnedToHost(const GraphView& graph,
+                                  const NodeDef& node) {
+  // Loop through all the inputs excluding the controlling nodes.
+  for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) {
+    // Check if (the fanin) op's device is on CPU.
+    if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) {
+      continue;
+    }
+
+    // Check if (the fanin) op's output port is pinned to HostMemory.
+    const OpDef* fanin_odef = nullptr;
+    Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef);
+    if (!s.ok()) {
+      LOG(INFO) << "Could not find OpDef for : " << fanin.node->op();
+      return false;
+    }
+
+    const int output_arg_id =
+        OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id);
+    if (output_arg_id < 0) {
+      LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
+                   << node.DebugString() << "\n"
+                   << fanin_odef->DebugString();
+      return false;
+    }
+
+    const KernelDef* fanin_kdef = nullptr;
+    s = TryFindKernelDef(*fanin.node, &fanin_kdef);
+    if (!s.ok()) {
+      LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op();
+      return false;
+    }
+
+    bool fanin_pinned = false;
+    for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) {
+      if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) {
+        fanin_pinned = true;
+        break;
+      }
+    }
+
+    if (!fanin_pinned) {
+      return false;
+    }
+  }
+
+  return true;
+}
+
+bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
+  // Check if Tensor is integer and small size.
+
+  // Check type to be int32 or int64.
+  if (prop.dtype() != DataType::DT_INT32 &&
+      prop.dtype() != DataType::DT_INT64) {
+    return false;
+  }
+
+  // Check size known and small.
+  const int64 size = NumCoefficients(prop.shape());
+  if (size < 0 || size > kTensorMaxSize) {
+    return false;
+  }
+
+  return true;
+}
+
+bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties,
+                                            const NodeDef& node) {
+  for (const auto& prop : properties.GetInputProperties(node.name())) {
+    if (!IsTensorIntegerAndSmall(prop)) {
+      return false;
+    }
+  }
+
+  for (const auto& prop : properties.GetOutputProperties(node.name())) {
+    if (!IsTensorIntegerAndSmall(prop)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+                         bool has_device_cpu, const string& device) {
+  // Force this node onto the CPU.
+  if (device.empty() && has_device_cpu) {
+    return "/device:CPU:0";
+  } else if (str_util::StrContains(device, DEVICE_GPU)) {
+    // Sometimes the cluster can have:
+    //   devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
+    // and we need to handle them properly.
+    for (const auto& device_match :
+         {std::pair<string, string>("GPU", "CPU:0"),
+          std::pair<string, string>("/device", "/device:CPU:0")}) {
+      const string device_host =
+          strings::StrCat(device.substr(0, device.rfind(device_match.first)),
+                          device_match.second);
+      if (devices.find(device_host) != devices.end()) {
+        return device_host;
+      }
+    }
+  }
+
+  // We couldn't find an appropriate Host device, return original device.
+  return device;
+}
+}  // end namespace internal
+
+Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+                                    GraphDef* optimized_graph) {
+  *optimized_graph = item.graph;
+
+  GraphProperties properties(item);
+  bool has_properties = false;
+  GraphView graph(optimized_graph);
+
+  gtl::FlatSet<string> devices;
+  if (cluster) {
+    const std::vector<string> device_names = cluster->GetDeviceNames();
+    devices.insert(device_names.begin(), device_names.end());
+  } else {
+    devices = {"/device:CPU:0"};
+  }
+
+  const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();
+
+  // Topologically sort the graph, so that we traverse the nodes in order. This
+  // will help us discover producer->consumer chains of Host ops.
+  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
+  for (auto& node : *optimized_graph->mutable_node()) {
+    // Check if node already on CPU.
+    if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+      continue;
+    }
+
+    // Check the node can be run on CPU.
+    Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr);
+    if (!s.ok()) {
+      continue;
+    }
+
+    // Check all input's are pinned to CPU.
+    if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) {
+      continue;
+    }
+
+    if (!has_properties) {
+      // This is an expensive call, call it lazily.
+      TF_RETURN_IF_ERROR(properties.InferStatically(false));
+      has_properties = true;
+    }
+
+    // Check all inputs and outputs are integers and small.
+    if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) {
+      continue;
+    }
+
+    // Try and swap the device to Host.
+    node.set_device(
+        internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
+  }
+  return Status::OK();
+}
+
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
new file mode 100644
index 0000000000..d557a03463
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+
+#include <unordered_set>
+
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+// Try and find an appropriate Host device in `devices` given `device`.
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+                         bool has_device_cpu, const string& device);
+}  // end namespace internal
+
+// Optimize TensorFlow ops that should be swapped into the CPU to avoid
+// excessive cpu<->gpu memcpy/sync.
+//
+// TODO(williamchan): The current heuristic will swap any small integer Const to
+// CPU. This may cause a problem cpu->cpu->gpu wherein the original behaviour of
+// gpu->gpu->gpu may have been better/faster. We should probably fix this.
+class PinToHostOptimizer : public GraphOptimizer {
+ public:
+  PinToHostOptimizer() : opt_level_(RewriterConfig::DEFAULT) {}
+  explicit PinToHostOptimizer(RewriterConfig::Toggle opt_level)
+      : opt_level_(opt_level) {}
+
+  ~PinToHostOptimizer() override {}
+
+  string name() const override { return "pin_to_host_optimizer"; };
+
+  Status Optimize(Cluster* cluster, const GrapplerItem& item,
+                  GraphDef* optimized_graph) override;
+
+  void Feedback(Cluster* cluster, const GrapplerItem& item,
+                const GraphDef& optimized_graph, double result) override {}
+
+ private:
+  RewriterConfig::Toggle opt_level_;
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
new file mode 100644
index 0000000000..339ddfd1b5
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -0,0 +1,162 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class PinToHostOptimizerTest : public GrapplerTest {};
+
+TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+  gtl::FlatSet<string> devices = {};
+  EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
+
+  devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+  EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
+            "/device:CPU:0");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
+            "/device:CPU:0");
+
+  devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+            "/device:XLA_CPU:0");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+            "/device:XLA_CPU:0");
+
+  devices = {"/device:XLA_GPU:0"};
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+            "/device:XLA_GPU:0");
+  EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+            "/device:XLA_GPU:*");
+}
+
+TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+  Output c = ops::Shape(s.WithOpName("c"), a);
+  Output d = ops::Const(s.WithOpName("d"), 0, {1});
+  Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+  GrapplerItem item;
+  item.fetch = {"a", "c", "d", "e"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  GraphDef output;
+  PinToHostOptimizer optimizer(RewriterConfig::ON);
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  auto tensors = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(tensors_expected.size(), tensors.size());
+  for (int i = 0; i < tensors.size(); ++i) {
+    test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+  }
+
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "a" || node.name() == "c") {
+      EXPECT_TRUE(node.device().empty());
+    } else if (node.name() == "d" || node.name() == "e") {
+      EXPECT_EQ(node.device(), "/device:CPU:0");
+    }
+    ++found;
+  }
+  EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, TopologicalSort) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+  Output c = ops::Shape(s.WithOpName("c"), a);
+  Output d = ops::Const(s.WithOpName("d"), 0, {1});
+  Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+  GrapplerItem item;
+  item.fetch = {"a", "c", "d", "e"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  // Reverse the graph, and hence rely on the optimizer to sort it.
+  std::reverse(item.graph.mutable_node()->begin(),
+               item.graph.mutable_node()->end());
+
+  GraphDef output;
+  PinToHostOptimizer optimizer(RewriterConfig::ON);
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  auto tensors = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(tensors_expected.size(), tensors.size());
+  for (int i = 0; i < tensors.size(); ++i) {
+    test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+  }
+
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    if (node.name() == "a" || node.name() == "c") {
+      EXPECT_TRUE(node.device().empty());
+    } else if (node.name() == "d" || node.name() == "e") {
+      EXPECT_EQ(node.device(), "/device:CPU:0");
+    }
+    ++found;
+  }
+  EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
+  ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+  GrapplerItem item;
+  item.fetch = {"a", "b"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+  GraphDef output;
+  PinToHostOptimizer optimizer(RewriterConfig::ON);
+  TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+  auto tensors = EvaluateNodes(item.graph, item.fetch);
+  EXPECT_EQ(tensors_expected.size(), tensors.size());
+  for (int i = 0; i < tensors.size(); ++i) {
+    test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+  }
+
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    EXPECT_EQ(node.device(), "/device:CPU:0");
+    ++found;
+  }
+  EXPECT_EQ(found, 2);
+}
+
+}  // namespace
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc
index 910b0acaef..6266733f3e 100644
--- a/tensorflow/core/grappler/utils/grappler_test.cc
+++ b/tensorflow/core/grappler/utils/grappler_test.cc
@@ -30,13 +30,16 @@ GrapplerTest::GrapplerTest() {
   // optimizations interfering in the comparison.
   RewriterConfig* cfg =
       options_.config.mutable_graph_options()->mutable_rewrite_options();
-  cfg->set_constant_folding(RewriterConfig::OFF);
+  // TODO(rmlarsen): Add utility to generate config w/ all optimizers turned
+  // off.
   cfg->set_arithmetic_optimization(RewriterConfig::OFF);
+  cfg->set_constant_folding(RewriterConfig::OFF);
+  cfg->set_debug_stripper(RewriterConfig::OFF);
   cfg->set_dependency_optimization(RewriterConfig::OFF);
-  cfg->set_loop_optimization(RewriterConfig::OFF);
   cfg->set_function_optimization(RewriterConfig::OFF);
   cfg->set_layout_optimizer(RewriterConfig::OFF);
-  cfg->set_debug_stripper(RewriterConfig::OFF);
+  cfg->set_loop_optimization(RewriterConfig::OFF);
+  cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
 }
 
 std::vector<Tensor> GrapplerTest::EvaluateNodes(
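
For reference, the port-to-arg mapping that AreAllNodeInputsPinnedToHost relies
on can be illustrated standalone. This sketch walks explicit per-arg list
lengths, where the real OpOutputPortIdToArgId reads each length from the
NodeDef's number_attr; it is an illustration under that assumption, not code
from this change.

#include <vector>

// For SparseSplit with num_split = 2 the output arg lengths are {2, 2, 2}:
// ports 0-1 -> arg 0, ports 2-3 -> arg 1, ports 4-5 -> arg 2, port 6 -> -1,
// matching the GraphViewTest.OpOutputPortIdToArgIdSparseSplit expectations.
int PortIdToArgId(const std::vector<int>& arg_lengths, int port_id) {
  for (int arg_id = 0; arg_id < static_cast<int>(arg_lengths.size());
       ++arg_id) {
    if (port_id < 0) return -1;                         // Invalid port.
    if (port_id < arg_lengths[arg_id]) return arg_id;   // Falls in this arg.
    port_id -= arg_lengths[arg_id];                     // Skip this arg's tensors.
  }
  return -1;  // port_id is past the last output tensor.
}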