author    | Yao Zhang <yaozhang@google.com>                  | 2017-12-22 09:44:32 -0800
committer | TensorFlower Gardener <gardener@tensorflow.org>  | 2017-12-22 09:48:15 -0800
commit    | bb1b0018ca8659f0edb9a9bff791a27c52e54e02 (patch)
tree      | 312ba687024aeae7ee75779754296c7f38226c2a
parent    | bb47c93c1aaa7691a6ad4e540fcbf0c6d337754e (diff)
Don't add layout transforms to control nodes; they are not needed there.
PiperOrigin-RevId: 179934839
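In a GraphDef, an input is a control input when its name starts with '^' (for example "^assign"); such edges only order execution and carry no tensor, so there is nothing for a layout transpose to act on. Below is a minimal standalone sketch of the filtering idea behind the new NonControlInputs() helper in this commit; FakeNodeDef and the node names are made up for the example and are not TensorFlow types.

#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for the NodeDef proto; illustration only.
struct FakeNodeDef {
  std::vector<std::string> inputs;
};

// In a GraphDef, control inputs are written with a leading '^'.
bool IsControlInput(const std::string& input_name) {
  return !input_name.empty() && input_name[0] == '^';
}

// Same filtering idea as the new NonControlInputs() helper in the diff:
// return the positions of the data inputs, skipping control edges.
std::vector<int> NonControlInputs(const FakeNodeDef& node) {
  std::vector<int> pos;
  for (int i = 0; i < static_cast<int>(node.inputs.size()); ++i) {
    if (!IsControlInput(node.inputs[i])) pos.push_back(i);
  }
  return pos;
}

int main() {
  // An AddN-like node with two data inputs and one control input.
  FakeNodeDef add_n{{"conv1", "conv2", "^assign"}};
  for (int i : NonControlInputs(add_n)) std::cout << i << ' ';  // prints: 0 1
  std::cout << '\n';
}

The same convention shows up in the IdentityNProcessor change below: ParseNodeName reports port -1 for a control input, which is what the new "Skip control input" check keys on.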
-rw-r--r-- | tensorflow/core/grappler/optimizers/layout_optimizer.cc | 157
-rw-r--r-- | tensorflow/python/BUILD                                 |   1
-rw-r--r-- | tensorflow/python/grappler/layout_optimizer_test.py     |  36
3 files changed, 108 insertions, 86 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index af4f83f936..c9e5f842be 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -318,6 +318,49 @@ bool IsBinaryOp(const NodeDef& node) {
   return is_binary;
 }
 
+std::vector<int> NonControlInputs(const NodeDef& node) {
+  std::vector<int> pos;
+  for (int i = 0; i < node.input_size(); i++) {
+    if (!IsControlInput(node.input(i))) {
+      pos.push_back(i);
+    }
+  }
+  return pos;
+}
+
+std::vector<int> DataInputPosConcat(const NodeDef& node) {
+  int n = node.attr().at("N").i();
+  std::vector<int> input_pos;
+  int start = (IsConcatV1(node)) ? 1 : 0;
+  int end = start + n;
+  for (int i = start; i < end; i++) {
+    input_pos.push_back(i);
+  }
+  return input_pos;
+}
+
+std::vector<int> DataInputPos(const NodeDef& node) {
+  if (IsSplit(node)) {
+    return {1};
+  }
+  if (IsBinaryOp(node) || IsUnaryGrad(node)) {
+    return {0, 1};
+  }
+  if (IsBetainc(node) || IsSelect(node)) {
+    return {0, 1, 2};
+  }
+  if (IsShapeN(node) || IsIdentityN(node) || IsAddN(node)) {
+    return NonControlInputs(node);
+  }
+  if (IsConcat(node)) {
+    return DataInputPosConcat(node);
+  }
+  if (node.input_size() > 0 && !IsControlInput(node.input(0))) {
+    return {0};
+  }
+  return {};
+}
+
 class GraphProcessor {
  public:
   GraphProcessor(const VirtualPlacer& virtual_placer,
@@ -1142,30 +1185,6 @@ class AgnosticNodeProcessor : public NodeProcessor {
   }
 
   bool IsNodeAfterNCHWToNHWC() const { return IsNodeAfterNCHWToNHWC(*node_); }
-
- private:
-  std::vector<int> DataInputPos(const NodeDef& node) const {
-    if (IsSplit(node) || IsConcatV1(node)) {
-      return {1};
-    }
-    if (IsBinaryOp(node) || IsUnaryGrad(node)) {
-      return {0, 1};
-    }
-    if (IsBetainc(node) || IsSelect(node)) {
-      return {0, 1, 2};
-    }
-    if (IsShapeN(node) || IsIdentityN(node)) {
-      std::vector<int> pos;
-      for (int i = 0; i < node.input_size(); i++) {
-        pos.push_back(i);
-      }
-      return pos;
-    }
-    if (node.input_size() > 0 && !IsControlInput(node.input(0))) {
-      return {0};
-    }
-    return {};
-  }
 };
 
 class AddNProcessor : public AgnosticNodeProcessor {
@@ -1175,12 +1194,7 @@ class AddNProcessor : public AgnosticNodeProcessor {
 
  protected:
   std::vector<int> GetInputPos() const override {
-    std::vector<int> input_pos;
-    input_pos.reserve(node_->input_size());
-    for (int i = 0; i < node_->input_size(); i++) {
-      input_pos.push_back(i);
-    }
-    return input_pos;
+    return NonControlInputs(*node_);
   }
 };
 
@@ -1325,24 +1339,20 @@ class ConcatProcessor : public AgnosticNodeProcessor {
   explicit ConcatProcessor(const OptimizeContext& opt_cxt)
       : AgnosticNodeProcessor(opt_cxt) {
     // For Concat, the concat axis is the first input; for ConcatV2,
-    // the last input.
-    axis_node_pos_ = (IsConcatV1(*node_)) ? 0 : (node_->input_size() - 1);
+    // the last input. Note that if with control inputs, the number of inputs
+    // is larger than the integer attribute N.
+    int n = node_->attr().at("N").i();
+    axis_node_pos_ = (IsConcatV1(*node_)) ? 0 : n;
   }
 
  protected:
   std::vector<int> GetInputPos() const override {
-    std::vector<int> input_pos;
-    int start = (IsConcatV1(*node_)) ? 1 : 0;
-    int end =
-        (IsConcatV1(*node_)) ? node_->input_size() : (node_->input_size() - 1);
-    for (int i = start; i < end; i++) {
-      input_pos.push_back(i);
-    }
-    return input_pos;
+    return DataInputPosConcat(*node_);
   }
 
   Status CustomizedProcessing() override {
-    DataType dtype = node_->attr().at("Tidx").type();
+    DataType dtype =
+        (IsConcatV1(*node_)) ? DT_INT32 : node_->attr().at("Tidx").type();
     TF_RETURN_IF_ERROR(
         UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap", dtype));
     return Status::OK();
@@ -1384,10 +1394,13 @@ class IdentityNProcessor : public AgnosticNodeProcessor {
       auto input = node_map_->GetNode(node_->input(i));
       int port;
       ParseNodeName(node_->input(i), &port);
-      if (IsPortDimsFour(*input, port) &&
-          (IsNodeAfterNCHWToNHWC(*input) ||
-           IsTransposeNCHWToNHWC(input->name()))) {
-        input_pos.push_back(i);
+      // Skip control input.
+      if (port != -1) {
+        if (IsPortDimsFour(*input, port) &&
+            (IsNodeAfterNCHWToNHWC(*input) ||
+             IsTransposeNCHWToNHWC(input->name()))) {
+          input_pos.push_back(i);
+        }
       }
     }
     return input_pos;
   }
@@ -1402,6 +1415,19 @@
   }
 };
 
+class ShapeProcessor : public IdentityNProcessor {
+ public:
+  explicit ShapeProcessor(const OptimizeContext& opt_cxt)
+      : IdentityNProcessor(opt_cxt) {}
+
+ protected:
+  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
+
+  Status CustomizedProcessing() override {
+    return AddTransformToOutputs("DataFormatVecPermute");
+  }
+};
+
 class MergeProcessor : public AgnosticNodeProcessor {
  public:
   explicit MergeProcessor(const OptimizeContext& opt_cxt)
@@ -1522,47 +1548,6 @@ class UnaryGradProcessor : public AgnosticNodeProcessor {
   std::vector<int> GetInputPos() const override { return {0, 1}; }
 };
 
-class ShapeProcessor : public AgnosticNodeProcessor {
- public:
-  explicit ShapeProcessor(const OptimizeContext& opt_cxt)
-      : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
-  bool ShouldProcess() const override {
-    return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
-           IsOnGPU();
-  }
-
-  std::vector<int> GetInputPos() const override {
-    std::vector<int> input_pos;
-    for (int i = 0; i < node_->input_size(); i++) {
-      auto input = node_map_->GetNode(node_->input(i));
-      int port;
-      ParseNodeName(node_->input(i), &port);
-      if (IsPortDimsFour(*input, port) &&
-          (IsNodeAfterNCHWToNHWC(*input) ||
-           IsTransposeNCHWToNHWC(input->name()))) {
-        input_pos.push_back(i);
-      }
-    }
-    return input_pos;
-  }
-
-  std::set<int> GetOutputPos() const override {
-    std::set<int> output_pos{};
-    for (const auto& input_pos : GetInputPos()) {
-      output_pos.insert(input_pos);
-    }
-    return output_pos;
-  }
-
-  Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
-
-  Status CustomizedProcessing() override {
-    return AddTransformToOutputs("DataFormatVecPermute");
-  }
-};
-
 class SliceProcessor : public AgnosticNodeProcessor {
  public:
   explicit SliceProcessor(const OptimizeContext& opt_cxt)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7c914c5b6a..b28f2c88c2 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -4540,6 +4540,7 @@ cuda_py_test(
         ":nn",
         ":ops",
         ":random_ops",
+        ":state_ops",
        ":tf_cluster",
         ":tf_optimizer",
         ":training",
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 151913b591..bb4f6388fd 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -40,6 +40,7 @@ from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import test
 from tensorflow.python.training import gradient_descent
@@ -350,6 +351,41 @@ class LayoutOptimizerTest(test.TestCase):
     self.assertIn('LayoutOptimizer-Pad-PaddingsConst', nodes)
     self.assertAllClose(output_val_ref, output_val, atol=1e-3)
 
+  def testConcatWithControlDependency(self):
+    if test.is_gpu_available(cuda_only=True):
+      random_seed.set_random_seed(0)
+      x = random_ops.truncated_normal([1, 784], seed=0)
+      conv = _two_layer_model(x)
+      axis = constant_op.constant(3)
+      var = variables.Variable(3)
+      assign = state_ops.assign(var, 6)
+      with ops.control_dependencies([assign]):
+        concat = array_ops.concat([conv, conv], axis)
+      output = array_ops.identity(concat)
+
+      with session.Session() as sess:
+        output_val_ref = sess.run(output)
+
+      with session.Session(config=_get_config()) as sess:
+        metadata = config_pb2.RunMetadata()
+        output_val = sess.run(output, run_metadata=metadata)
+
+      nodes = []
+      num_transposes = 0
+      for node in metadata.cost_graph.node:
+        if node.name.startswith('LayoutOptimizerTranspose'):
+          num_transposes += 1
+        nodes.append(node.name)
+
+      # Four transposes were initially added in the Expand phase of
+      # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+      expected_num_transposes = 2
+      self.assertEqual(expected_num_transposes, num_transposes)
+      self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+      self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-concat-0-0', nodes)
+      self.assertIn('LayoutOptimizer-concat-Const_2', nodes)
+      self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
   def testFill(self):
     if test.is_gpu_available(cuda_only=True):
       random_seed.set_random_seed(0)
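The subtlest part of the change is the ConcatProcessor constructor: ConcatV2's inputs are laid out as N data tensors followed by the axis tensor, and any control inputs are appended after that, so node_->input_size() - 1 stops pointing at the axis as soon as a control input is present, while the integer attribute N always does. A standalone sketch of that off-by-control-input, again with a made-up stand-in for NodeDef (the struct and names are illustrative only):

#include <cassert>
#include <string>
#include <vector>

// Simplified stand-in for the NodeDef proto; illustration only.
struct FakeNodeDef {
  std::vector<std::string> inputs;
  int n;  // stand-in for the integer attribute "N" (number of data tensors)
};

int main() {
  // A ConcatV2-like node with N = 2: [data, data, axis, control].
  FakeNodeDef concat{{"conv", "conv", "axis_const", "^assign"}, 2};

  // Old logic: the axis was assumed to be the last input.
  int old_axis_pos = static_cast<int>(concat.inputs.size()) - 1;
  // New logic: the axis sits right after the N data tensors.
  int new_axis_pos = concat.n;

  assert(concat.inputs[old_axis_pos] == "^assign");     // wrong: a control edge
  assert(concat.inputs[new_axis_pos] == "axis_const");  // correct: the axis
  return 0;
}

The new testConcatWithControlDependency exercises exactly this shape of node: the control dependency on the assign op appends a control input after the axis constant, and the test checks that the optimizer still rewrites the axis input ('LayoutOptimizer-concat-Const_2') rather than the control edge.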