 tensorflow/core/grappler/optimizers/layout_optimizer.cc | 66
 tensorflow/python/grappler/layout_optimizer_test.py     | 37
 2 files changed, 96 insertions(+), 7 deletions(-)
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 43dd9e193a..aef762b196 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -1551,13 +1551,13 @@ class SliceProcessor : public AgnosticNodeProcessor {
: AgnosticNodeProcessor(opt_cxt) {
// Skip the first input, which is the data to be sliced.
start_ = 1;
- // For StridedSlice, the last param is at index 3. Note that we can't use
- // node_->input_size() here because there could be control inputs.
- end_ = IsSlice(*node_) ? 2 : 3;
+ // Note that we can't use node_->input_size() here because there
+ // could be control inputs.
+ end_ = 2;
}
protected:
- Status CustomizedProcessing() override {
+ Status ProcessInputs() {
for (int i = start_; i <= end_; i++) {
DataType dtype = node_->attr().at("Index").type();
TF_RETURN_IF_ERROR(
@@ -1565,14 +1565,64 @@ class SliceProcessor : public AgnosticNodeProcessor {
}
return Status::OK();
}
+
+ Status CustomizedProcessing() override { return ProcessInputs(); }
+
int start_;
int end_;
};
-class StridedSliceGradProcessor : public SliceProcessor {
+class StridedSliceProcessor : public SliceProcessor {
public:
- explicit StridedSliceGradProcessor(const OptimizeContext& opt_cxt)
+ explicit StridedSliceProcessor(const OptimizeContext& opt_cxt)
: SliceProcessor(opt_cxt) {
+ start_ = 1;
+ end_ = 3;
+ }
+
+ protected:
+ bool ShouldProcess() const override {
+ return AgnosticNodeProcessor::ShouldProcess() && IsOnlyBeginEndMask();
+ }
+
+ Status CustomizedProcessing() override {
+ TF_RETURN_IF_ERROR(UpdateMask("begin_mask"));
+ TF_RETURN_IF_ERROR(UpdateMask("end_mask"));
+ TF_RETURN_IF_ERROR(ProcessInputs());
+ return Status::OK();
+ }
+
+ private:
+ bool IsMaskZero(const string& mask) const {
+ return node_->attr().at(mask).i() == 0;
+ }
+
+ bool IsOnlyBeginEndMask() const {
+ return IsMaskZero("ellipsis_mask") && IsMaskZero("new_axis_mask") &&
+ IsMaskZero("shrink_axis_mask");
+ }
+
+ Status UpdateMask(const string& mask) {
+ int i = node_->attr().at(mask).i();
+ if (i < 0 || i > 15) {
+ return errors::InvalidArgument("invalid mask value: ", i);
+ }
+ if (i == 0 || i == 1 || i == 14 || i == 15) return Status::OK();
+ if (i == 2 || i == 3) i += 2;
+ if (i == 4 || i == 5) i += 4;
+ if (i == 6 || i == 7) i += 6;
+ if (i == 8 || i == 9) i -= 6;
+ if (i == 10 || i == 11) i -= 4;
+ if (i == 12 || i == 13) i -= 2;
+ node_->mutable_attr()->at(mask).set_i(i);
+ return Status::OK();
+ }
+};
+
+class StridedSliceGradProcessor : public StridedSliceProcessor {
+ public:
+ explicit StridedSliceGradProcessor(const OptimizeContext& opt_cxt)
+ : StridedSliceProcessor(opt_cxt) {
start_ = 0;
end_ = 3;
}
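
The lookup table in UpdateMask above is just a bit permutation: each of the four mask bits corresponds to one dimension of the 4-D input, and converting from NHWC to NCHW keeps N at bit 0, moves H from bit 1 to bit 2, W from bit 2 to bit 3, and C from bit 3 to bit 1. A minimal standalone sketch (hypothetical names, not part of the patch) that computes the same mapping generically and checks it against the constants hard-coded above:

    # Sketch only: remap a 4-bit StridedSlice mask from NHWC bit positions to
    # NCHW bit positions, then verify it reproduces the UpdateMask table.
    NHWC_TO_NCHW_DIM = {0: 0, 1: 2, 2: 3, 3: 1}  # N stays; H->2, W->3, C->1

    def remap_mask(mask):
        assert 0 <= mask <= 15  # same 4-D range check as in UpdateMask
        out = 0
        for src_bit, dst_bit in NHWC_TO_NCHW_DIM.items():
            if mask & (1 << src_bit):
                out |= 1 << dst_bit
        return out

    # UpdateMask written as "offset added to the original value".
    OFFSETS = {0: 0, 1: 0, 2: 2, 3: 2, 4: 4, 5: 4, 6: 6, 7: 6,
               8: -6, 9: -6, 10: -4, 11: -4, 12: -2, 13: -2, 14: 0, 15: 0}
    for m in range(16):
        assert remap_mask(m) == m + OFFSETS[m], m
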
@@ -1794,8 +1844,10 @@ class DataLayoutOptimizer : GraphProcessor {
node_processor.reset(new PadProcessor(opt_cxt));
} else if (IsReverseV2(*node)) {
node_processor.reset(new ReverseProcessor(opt_cxt));
- } else if (IsSlice(*node) || IsStridedSlice(*node)) {
+ } else if (IsSlice(*node)) {
node_processor.reset(new SliceProcessor(opt_cxt));
+ } else if (IsStridedSlice(*node)) {
+ node_processor.reset(new StridedSliceProcessor(opt_cxt));
} else if (IsShape(*node) || IsShapeN(*node)) {
node_processor.reset(new ShapeProcessor(opt_cxt));
} else if (IsSplit(*node)) {
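
The start_/end_ ranges chosen by the processors above follow from the op signatures: Slice takes (input, begin, size), so only inputs 1-2 hold layout-dependent indices; StridedSlice takes (input, begin, end, strides), so inputs 1-3; StridedSliceGrad takes (shape, begin, end, strides, dy), where the shape vector at input 0 must be permuted as well, hence start_ = 0. A rough sketch of the intended index permutation (illustrative only; the patch performs the rewrite inside ProcessInputs above, not through a helper like this):

    # Sketch only: which integer inputs each processor touches, and how a
    # 4-element begin/end/size/strides/shape vector is reordered for NCHW.
    LAYOUT_SENSITIVE_INPUTS = {
        "Slice": (1, 2),             # inputs: (input, begin, size)
        "StridedSlice": (1, 3),      # inputs: (input, begin, end, strides)
        "StridedSliceGrad": (0, 3),  # inputs: (shape, begin, end, strides, dy)
    }

    def nhwc_to_nchw(vec4):
        # [N, H, W, C] -> [N, C, H, W]
        n, h, w, c = vec4
        return [n, c, h, w]

    # e.g. an NHWC begin of [0, 0, 1, 0] becomes [0, 0, 0, 1] in NCHW.
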
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 1de0c7df9f..f2985d8089 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -750,6 +750,43 @@ class LayoutOptimizerTest(test.TestCase):
self.assertIn('LayoutOptimizer-StridedSlice-StridedSlice/strides', nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+ def testStridedSliceWithMask(self):
+ if test.is_gpu_available(cuda_only=True):
+ random_seed.set_random_seed(0)
+ x = random_ops.truncated_normal([1, 784], seed=0)
+ conv = _two_layer_model(x)
+ # This will generate a StridedSlice op with begin mask and end mask.
+ s = conv[:, :, 1:-1, :]
+ output = array_ops.identity(s)
+
+ with session.Session() as sess:
+ output_val_ref = sess.run(output)
+
+ with session.Session(config=_get_config()) as sess:
+ metadata = config_pb2.RunMetadata()
+ output_val = sess.run(output, run_metadata=metadata)
+
+ nodes = []
+ num_transposes = 0
+ for node in metadata.cost_graph.node:
+ if node.name.startswith('LayoutOptimizerTranspose'):
+ num_transposes += 1
+ nodes.append(node.name)
+
+ # Four transposes were initially added in the Expand phase of
+ # LayoutOptimizer; two of them are cancelled out in the Collapse phase.
+ expected_num_transposes = 2
+ self.assertEqual(expected_num_transposes, num_transposes)
+ self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes)
+ self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-strided_slice-0-0',
+ nodes)
+ self.assertIn('LayoutOptimizer-strided_slice-strided_slice/stack', nodes)
+ self.assertIn('LayoutOptimizer-strided_slice-strided_slice/stack_1',
+ nodes)
+ self.assertIn('LayoutOptimizer-strided_slice-strided_slice/stack_2',
+ nodes)
+ self.assertAllClose(output_val_ref, output_val, atol=1e-3)
+
def testStridedSliceGradWithNonConstAxis(self):
if test.is_gpu_available(cuda_only=True):
random_seed.set_random_seed(0)
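
For the new test above, conv[:, :, 1:-1, :] specifies only dimension 2 (W in NHWC) and leaves dimensions 0, 1 and 3 as full slices, so, assuming the usual behavior of TensorFlow's strided-slice helper (a mask bit is set for every omitted begin/end), both begin_mask and end_mask should come out as 0b1011 = 11. After the optimizer converts the surrounding graph to NCHW, UpdateMask maps 11 to 7. A quick, self-contained check of that arithmetic:

    # Sketch only: expected mask rewrite for the slice in testStridedSliceWithMask.
    nhwc_mask = 0b1011                          # N, H, C fully sliced -> 11
    # N stays at bit 0; H (bit 1) -> bit 2, W (bit 2) -> bit 3, C (bit 3) -> bit 1.
    nchw_mask = ((nhwc_mask & 0b0001)           # N
                 | ((nhwc_mask >> 1 & 1) << 2)  # H -> bit 2
                 | ((nhwc_mask >> 2 & 1) << 3)  # W -> bit 3
                 | ((nhwc_mask >> 3 & 1) << 1)) # C -> bit 1
    assert nchw_mask == 7                       # matches the i == 11 -> i - 4 row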