author    Yao Zhang <yaozhang@google.com>  2017-12-22 09:44:32 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-12-22 09:48:15 -0800
commit bb1b0018ca8659f0edb9a9bff791a27c52e54e02 (patch)
tree   312ba687024aeae7ee75779754296c7f38226c2a
parent bb47c93c1aaa7691a6ad4e540fcbf0c6d337754e (diff)
Do not add layout transforms to control nodes; they are not needed there.
PiperOrigin-RevId: 179934839
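Background for the change: in a GraphDef, a control dependency is stored as an ordinary entry in a node's input list, distinguished only by a leading '^' (e.g. "^assign"). A minimal sketch of that convention, with a standalone helper named here purely for illustration (the patch itself uses grappler's existing IsControlInput):

#include <string>

// Sketch, not part of the patch: a control input is any NodeDef input
// whose name starts with '^'. It carries no tensor value, so inserting
// a layout transpose on it would be meaningless.
bool IsControlInputName(const std::string& input_name) {
  return !input_name.empty() && input_name[0] == '^';
}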
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.cc | 157
-rw-r--r--  tensorflow/python/BUILD                                 |   1
-rw-r--r--  tensorflow/python/grappler/layout_optimizer_test.py     |  36
3 files changed, 108 insertions(+), 86 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index af4f83f936..c9e5f842be 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -318,6 +318,49 @@ bool IsBinaryOp(const NodeDef& node) {
return is_binary;
}
+std::vector<int> NonControlInputs(const NodeDef& node) {
+ std::vector<int> pos;
+ for (int i = 0; i < node.input_size(); i++) {
+ if (!IsControlInput(node.input(i))) {
+ pos.push_back(i);
+ }
+ }
+ return pos;
+}
+
+std::vector<int> DataInputPosConcat(const NodeDef& node) {
+ int n = node.attr().at("N").i();
+ std::vector<int> input_pos;
+ int start = (IsConcatV1(node)) ? 1 : 0;
+ int end = start + n;
+ for (int i = start; i < end; i++) {
+ input_pos.push_back(i);
+ }
+ return input_pos;
+}
+
+std::vector<int> DataInputPos(const NodeDef& node) {
+ if (IsSplit(node)) {
+ return {1};
+ }
+ if (IsBinaryOp(node) || IsUnaryGrad(node)) {
+ return {0, 1};
+ }
+ if (IsBetainc(node) || IsSelect(node)) {
+ return {0, 1, 2};
+ }
+ if (IsShapeN(node) || IsIdentityN(node) || IsAddN(node)) {
+ return NonControlInputs(node);
+ }
+ if (IsConcat(node)) {
+ return DataInputPosConcat(node);
+ }
+ if (node.input_size() > 0 && !IsControlInput(node.input(0))) {
+ return {0};
+ }
+ return {};
+}
+
class GraphProcessor {
public:
GraphProcessor(const VirtualPlacer& virtual_placer,
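To see what the new helpers compute, consider a hypothetical ConcatV2 node that has picked up a control dependency (example not part of the patch; the node and input names are made up):

#include "tensorflow/core/framework/node_def.pb.h"

// A ConcatV2 node with two data tensors, the axis input, and a
// trailing control input appended by the control edge.
tensorflow::NodeDef MakeConcatV2WithControlDep() {
  tensorflow::NodeDef node;
  node.set_op("ConcatV2");
  node.add_input("conv_a");   // data input, position 0
  node.add_input("conv_b");   // data input, position 1
  node.add_input("axis");     // axis input, position N == 2
  node.add_input("^assign");  // control input, position 3
  (*node.mutable_attr())["N"].set_i(2);  // N counts data tensors only
  return node;
}
// DataInputPosConcat returns {0, 1}: only the data tensors are
// candidates for a layout transform; the axis input and the control
// input are left alone.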
@@ -1142,30 +1185,6 @@ class AgnosticNodeProcessor : public NodeProcessor {
}
bool IsNodeAfterNCHWToNHWC() const { return IsNodeAfterNCHWToNHWC(*node_); }
-
- private:
- std::vector<int> DataInputPos(const NodeDef& node) const {
- if (IsSplit(node) || IsConcatV1(node)) {
- return {1};
- }
- if (IsBinaryOp(node) || IsUnaryGrad(node)) {
- return {0, 1};
- }
- if (IsBetainc(node) || IsSelect(node)) {
- return {0, 1, 2};
- }
- if (IsShapeN(node) || IsIdentityN(node)) {
- std::vector<int> pos;
- for (int i = 0; i < node.input_size(); i++) {
- pos.push_back(i);
- }
- return pos;
- }
- if (node.input_size() > 0 && !IsControlInput(node.input(0))) {
- return {0};
- }
- return {};
- }
};
class AddNProcessor : public AgnosticNodeProcessor {
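Note: the private AgnosticNodeProcessor::DataInputPos removed above is hoisted to the free function added at the top of this diff, with two behavioral changes visible in the hunks: Concat/ConcatV1 now go through DataInputPosConcat (returning every data-tensor position rather than just {1} for ConcatV1), and ShapeN/IdentityN/AddN use NonControlInputs so control inputs are no longer treated as data.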
@@ -1175,12 +1194,7 @@ class AddNProcessor : public AgnosticNodeProcessor {
protected:
std::vector<int> GetInputPos() const override {
- std::vector<int> input_pos;
- input_pos.reserve(node_->input_size());
- for (int i = 0; i < node_->input_size(); i++) {
- input_pos.push_back(i);
- }
- return input_pos;
+ return NonControlInputs(*node_);
}
};
@@ -1325,24 +1339,20 @@ class ConcatProcessor : public AgnosticNodeProcessor {
explicit ConcatProcessor(const OptimizeContext& opt_cxt)
: AgnosticNodeProcessor(opt_cxt) {
// For Concat, the concat axis is the first input; for ConcatV2,
- // the last input.
- axis_node_pos_ = (IsConcatV1(*node_)) ? 0 : (node_->input_size() - 1);
+ // the last input. Note that when control inputs are present, the total
+ // number of inputs is larger than the integer attribute N.
+ int n = node_->attr().at("N").i();
+ axis_node_pos_ = (IsConcatV1(*node_)) ? 0 : n;
}
protected:
std::vector<int> GetInputPos() const override {
- std::vector<int> input_pos;
- int start = (IsConcatV1(*node_)) ? 1 : 0;
- int end =
- (IsConcatV1(*node_)) ? node_->input_size() : (node_->input_size() - 1);
- for (int i = start; i < end; i++) {
- input_pos.push_back(i);
- }
- return input_pos;
+ return DataInputPosConcat(*node_);
}
Status CustomizedProcessing() override {
- DataType dtype = node_->attr().at("Tidx").type();
+ DataType dtype =
+ (IsConcatV1(*node_)) ? DT_INT32 : node_->attr().at("Tidx").type();
TF_RETURN_IF_ERROR(
UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap", dtype));
return Status::OK();
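Two fixes land in ConcatProcessor. First, once a control dependency is attached, node_->input_size() - 1 no longer points at the ConcatV2 axis (in the example above it would select the control input at position 3), so the axis position is now derived from the attribute N. Second, Concat (V1) has no Tidx attribute; its concat_dim input is always int32, so reading Tidx on a V1 node would fail, and the code now falls back to DT_INT32 for V1.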
@@ -1384,10 +1394,13 @@ class IdentityNProcessor : public AgnosticNodeProcessor {
auto input = node_map_->GetNode(node_->input(i));
int port;
ParseNodeName(node_->input(i), &port);
- if (IsPortDimsFour(*input, port) &&
- (IsNodeAfterNCHWToNHWC(*input) ||
- IsTransposeNCHWToNHWC(input->name()))) {
- input_pos.push_back(i);
+ // Skip control inputs; ParseNodeName reports port -1 for them.
+ if (port != -1) {
+ if (IsPortDimsFour(*input, port) &&
+ (IsNodeAfterNCHWToNHWC(*input) ||
+ IsTransposeNCHWToNHWC(input->name()))) {
+ input_pos.push_back(i);
+ }
}
}
return input_pos;
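ParseNodeName reports port -1 for a control input (the '^name' form addresses no output tensor), so the new port != -1 guard makes IdentityNProcessor skip control inputs when collecting positions to transpose.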
@@ -1402,6 +1415,19 @@ class IdentityNProcessor : public AgnosticNodeProcessor {
}
};
+class ShapeProcessor : public IdentityNProcessor {
+ public:
+ explicit ShapeProcessor(const OptimizeContext& opt_cxt)
+ : IdentityNProcessor(opt_cxt) {}
+
+ protected:
+ Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
+
+ Status CustomizedProcessing() override {
+ return AddTransformToOutputs("DataFormatVecPermute");
+ }
+};
+
class MergeProcessor : public AgnosticNodeProcessor {
public:
explicit MergeProcessor(const OptimizeContext& opt_cxt)
@@ -1522,47 +1548,6 @@ class UnaryGradProcessor : public AgnosticNodeProcessor {
std::vector<int> GetInputPos() const override { return {0, 1}; }
};
-class ShapeProcessor : public AgnosticNodeProcessor {
- public:
- explicit ShapeProcessor(const OptimizeContext& opt_cxt)
- : AgnosticNodeProcessor(opt_cxt) {}
-
- protected:
- bool ShouldProcess() const override {
- return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
- IsOnGPU();
- }
-
- std::vector<int> GetInputPos() const override {
- std::vector<int> input_pos;
- for (int i = 0; i < node_->input_size(); i++) {
- auto input = node_map_->GetNode(node_->input(i));
- int port;
- ParseNodeName(node_->input(i), &port);
- if (IsPortDimsFour(*input, port) &&
- (IsNodeAfterNCHWToNHWC(*input) ||
- IsTransposeNCHWToNHWC(input->name()))) {
- input_pos.push_back(i);
- }
- }
- return input_pos;
- }
-
- std::set<int> GetOutputPos() const override {
- std::set<int> output_pos{};
- for (const auto& input_pos : GetInputPos()) {
- output_pos.insert(input_pos);
- }
- return output_pos;
- }
-
- Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
-
- Status CustomizedProcessing() override {
- return AddTransformToOutputs("DataFormatVecPermute");
- }
-};
-
class SliceProcessor : public AgnosticNodeProcessor {
public:
explicit SliceProcessor(const OptimizeContext& opt_cxt)
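The standalone ShapeProcessor deleted above duplicated IdentityNProcessor's input/output scanning verbatim. It is replaced by the ShapeProcessor added earlier in this diff, which derives from IdentityNProcessor, inherits the now control-input-aware GetInputPos, and keeps only its two overrides.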
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 7c914c5b6a..b28f2c88c2 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -4540,6 +4540,7 @@ cuda_py_test(
":nn",
":ops",
":random_ops",
+ ":state_ops",
":tf_cluster",
":tf_optimizer",
":training",
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 151913b591..bb4f6388fd 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -40,6 +40,7 @@ from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
@@ -350,6 +351,41 @@ class LayoutOptimizerTest(test.TestCase):
self.assertIn('LayoutOptimizer-Pad-PaddingsConst', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+ def testConcatWithControlDependency(self):
+ if test.is_gpu_available(cuda_only=True):
+ random_seed.set_random_seed(0)
+ x = random_ops.truncated_normal([1, 784], seed=0)
+ conv = _two_layer_model(x)
+ axis = constant_op.constant(3)
+ var = variables.Variable(3)
+ assign = state_ops.assign(var, 6)
+ with ops.control_dependencies([assign]):
+ concat = array_ops.concat([conv, conv], axis)
+ output = array_ops.identity(concat)
+
+ with session.Session() as sess:
+ output_val_ref = sess.run(output)
+
+ with session.Session(config=_get_config()) as sess:
+ metadata = config_pb2.RunMetadata()
+ output_val = sess.run(output, run_metadata=metadata)
+
+ nodes = []
+ num_transposes = 0
+ for node in metadata.cost_graph.node:
+ if node.name.startswith('LayoutOptimizerTranspose'):
+ num_transposes += 1
+ nodes.append(node.name)
+
+ # Four transposes were initially added in the Expand phase of
+ # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+ expected_num_transposes = 2
+ self.assertEqual(expected_num_transposes, num_transposes)
+ self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+ self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-concat-0-0', nodes)
+ self.assertIn('LayoutOptimizer-concat-Const_2', nodes)
+ self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
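The new test attaches a control dependency to the concat node, so the ConcatV2 seen by the optimizer carries a trailing '^assign' input. Before this change the optimizer took that control input for the axis and would also try to add a layout transform to it; with the fix, only the two expected transposes around Conv2D/concat and the axis-map constant LayoutOptimizer-concat-Const_2 appear.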
def testFill(self):
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(0)