diff options
author | Yao Zhang <yaozhang@google.com> | 2018-04-06 19:52:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-06 19:54:51 -0700 |
commit | 4ce9d3a577ba3d9a0c46f05534510c00028652e6 (patch) | |
tree | 96553f9461dd4ed72d4cc332dc6a757099fff82d | |
parent | 992d1ebaab7f234bc0b8f28c524236e3cea580ab (diff) |
Place data format op on host if input tensor is in host memory.
PiperOrigin-RevId: 191972759
-rw-r--r-- | tensorflow/core/grappler/optimizers/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/layout_optimizer.cc | 47 | ||||
-rw-r--r-- | tensorflow/core/grappler/optimizers/layout_optimizer_test.cc | 24 |
3 files changed, 70 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 122fd48584..e4bc030885 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -480,6 +480,7 @@ tf_cuda_cc_test( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/costs:virtual_placer", ], ) diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 308eecd420..561226f945 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -17,9 +17,13 @@ limitations under the License. #include <unordered_set> #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/memory_types.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/devices.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -363,6 +367,28 @@ std::vector<int> DataInputPos(const NodeDef& node) { return {}; } +bool IsHostMemory(const NodeDef& node, int output_port) { + DeviceNameUtils::ParsedName parsed_name; + if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) { + DeviceType device_type(parsed_name.type); + Status s = FindKernelDef(device_type, node, nullptr, nullptr); + if (s.ok()) { + tensorflow::MemoryTypeVector in_mtypes; + tensorflow::MemoryTypeVector out_mtypes; + s = tensorflow::MemoryTypesForNode(OpRegistry::Global(), device_type, + node, &in_mtypes, &out_mtypes); + if (s.ok()) { + if (out_mtypes[output_port] == HOST_MEMORY) { + return true; + } + } + } else { + return true; + } + } + return false; +} + class GraphProcessor { public: GraphProcessor(const GraphProperties& graph_properties, @@ -883,6 +909,23 @@ class NodeProcessor : public GraphProcessor { list->set_i(3, w); } + string MaybeGetHostDevice(const string& input_name) const { + string device = node_->device(); + DeviceNameUtils::ParsedName parsed_name; + if (DeviceNameUtils::ParseFullName(device, &parsed_name)) { + if (parsed_name.type != "CPU") { + NodeDef* input = node_map_->GetNode(input_name); + int port; + ParseNodeName(input_name, &port); + if (IsHostMemory(*input, port)) { + parsed_name.type = "CPU"; + device = DeviceNameUtils::ParsedNameToString(parsed_name); + } + } + } + return device; + } + NodeDef* AddNodeDataFormatOp(const string& name, const string& input_name, const string& op, DataType dtype, bool nhwc_to_nchw) { @@ -890,7 +933,9 @@ class NodeProcessor : public GraphProcessor { added_node->set_name(name); added_node->set_op(op); node_map_->AddNode(added_node->name(), added_node); - added_node->set_device(node_->device()); + // The inputs of a DataFormat op could be in host memory for ops such as + // Reshape. + added_node->set_device(MaybeGetHostDevice(input_name)); AttrValue attr_data_type; attr_data_type.set_type(dtype); added_node->mutable_attr()->insert({"T", attr_data_type}); diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc index 1c912fcaa2..260347b0e8 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/grappler/costs/virtual_placer.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -158,7 +159,7 @@ class LayoutOptimizerTest : public ::testing::Test { return output.x_backprop; } - std::unique_ptr<VirtualCluster> virtual_cluster_; + std::unique_ptr<Cluster> virtual_cluster_; }; TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) { @@ -1130,6 +1131,27 @@ TEST_F(LayoutOptimizerTest, LoopNoLiveLock) { EXPECT_EQ(mul_node->input(0), "Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer"); } + +TEST_F(LayoutOptimizerTest, DevicePlacement) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto conv = SimpleConv2D(&s, 4, 2, "VALID"); + auto shape = ops::Shape(s.WithOpName("s"), conv); + auto i = ops::Identity(s.WithOpName("i"), shape); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + VirtualPlacer virtual_placer(virtual_cluster_.get()); + for (auto& node : *item.graph.mutable_node()) { + string device = virtual_placer.get_canonical_device_name(node); + node.set_device(device); + } + LayoutOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output); + NodeMap node_map(&output); + auto vec_permute = + node_map.GetNode("s-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer"); + EXPECT_EQ(vec_permute->device(), "/device:CPU:0"); +} } // namespace } // namespace grappler } // namespace tensorflow |