aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2018-04-06 19:52:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-06 19:54:51 -0700
commit4ce9d3a577ba3d9a0c46f05534510c00028652e6 (patch)
tree96553f9461dd4ed72d4cc332dc6a757099fff82d
parent992d1ebaab7f234bc0b8f28c524236e3cea580ab (diff)
Place data format op on host if input tensor is in host memory.
PiperOrigin-RevId: 191972759
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD1
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer.cc47
-rw-r--r--tensorflow/core/grappler/optimizers/layout_optimizer_test.cc24
3 files changed, 70 insertions, 2 deletions
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 122fd48584..e4bc030885 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -480,6 +480,7 @@ tf_cuda_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:virtual_cluster",
+ "//tensorflow/core/grappler/costs:virtual_placer",
],
)
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 308eecd420..561226f945 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -17,9 +17,13 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/grappler/grappler_item.h"
@@ -363,6 +367,28 @@ std::vector<int> DataInputPos(const NodeDef& node) {
return {};
}
+bool IsHostMemory(const NodeDef& node, int output_port) {
+ DeviceNameUtils::ParsedName parsed_name;
+ if (DeviceNameUtils::ParseFullName(node.device(), &parsed_name)) {
+ DeviceType device_type(parsed_name.type);
+ Status s = FindKernelDef(device_type, node, nullptr, nullptr);
+ if (s.ok()) {
+ tensorflow::MemoryTypeVector in_mtypes;
+ tensorflow::MemoryTypeVector out_mtypes;
+ s = tensorflow::MemoryTypesForNode(OpRegistry::Global(), device_type,
+ node, &in_mtypes, &out_mtypes);
+ if (s.ok()) {
+ if (out_mtypes[output_port] == HOST_MEMORY) {
+ return true;
+ }
+ }
+ } else {
+ return true;
+ }
+ }
+ return false;
+}
+
class GraphProcessor {
public:
GraphProcessor(const GraphProperties& graph_properties,
@@ -883,6 +909,23 @@ class NodeProcessor : public GraphProcessor {
list->set_i(3, w);
}
+ string MaybeGetHostDevice(const string& input_name) const {
+ string device = node_->device();
+ DeviceNameUtils::ParsedName parsed_name;
+ if (DeviceNameUtils::ParseFullName(device, &parsed_name)) {
+ if (parsed_name.type != "CPU") {
+ NodeDef* input = node_map_->GetNode(input_name);
+ int port;
+ ParseNodeName(input_name, &port);
+ if (IsHostMemory(*input, port)) {
+ parsed_name.type = "CPU";
+ device = DeviceNameUtils::ParsedNameToString(parsed_name);
+ }
+ }
+ }
+ return device;
+ }
+
NodeDef* AddNodeDataFormatOp(const string& name, const string& input_name,
const string& op, DataType dtype,
bool nhwc_to_nchw) {
@@ -890,7 +933,9 @@ class NodeProcessor : public GraphProcessor {
added_node->set_name(name);
added_node->set_op(op);
node_map_->AddNode(added_node->name(), added_node);
- added_node->set_device(node_->device());
+ // The inputs of a DataFormat op could be in host memory for ops such as
+ // Reshape.
+ added_node->set_device(MaybeGetHostDevice(input_name));
AttrValue attr_data_type;
attr_data_type.set_type(dtype);
added_node->mutable_attr()->insert({"T", attr_data_type});
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index 1c912fcaa2..260347b0e8 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
+#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -158,7 +159,7 @@ class LayoutOptimizerTest : public ::testing::Test {
return output.x_backprop;
}
- std::unique_ptr<VirtualCluster> virtual_cluster_;
+ std::unique_ptr<Cluster> virtual_cluster_;
};
TEST_F(LayoutOptimizerTest, Conv2DBackpropInput) {
@@ -1130,6 +1131,27 @@ TEST_F(LayoutOptimizerTest, LoopNoLiveLock) {
EXPECT_EQ(mul_node->input(0),
"Conv2D-0-0-TransposeNCHWToNHWC-LayoutOptimizer");
}
+
+TEST_F(LayoutOptimizerTest, DevicePlacement) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 4, 2, "VALID");
+ auto shape = ops::Shape(s.WithOpName("s"), conv);
+ auto i = ops::Identity(s.WithOpName("i"), shape);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ VirtualPlacer virtual_placer(virtual_cluster_.get());
+ for (auto& node : *item.graph.mutable_node()) {
+ string device = virtual_placer.get_canonical_device_name(node);
+ node.set_device(device);
+ }
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto vec_permute =
+ node_map.GetNode("s-0-0-VecPermuteNCHWToNHWC-LayoutOptimizer");
+ EXPECT_EQ(vec_permute->device(), "/device:CPU:0");
+}
} // namespace
} // namespace grappler
} // namespace tensorflow