diff options
-rw-r--r-- | tensorflow/core/grappler/optimizers/layout_optimizer.cc | 66 | ||||
-rw-r--r-- | tensorflow/python/grappler/layout_optimizer_test.py | 37 |
2 files changed, 96 insertions, 7 deletions
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 43dd9e193a..aef762b196 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -1551,13 +1551,13 @@ class SliceProcessor : public AgnosticNodeProcessor { : AgnosticNodeProcessor(opt_cxt) { // Skip the first input, which is the data to be sliced. start_ = 1; - // For StridedSlice, the last param is at index 3. Note that we can't use - // node_->input_size() here because there could be control inputs. - end_ = IsSlice(*node_) ? 2 : 3; + // Note that we can't use node_->input_size() here because there + // could be control inputs. + end_ = 2; } protected: - Status CustomizedProcessing() override { + Status ProcessInputs() { for (int i = start_; i <= end_; i++) { DataType dtype = node_->attr().at("Index").type(); TF_RETURN_IF_ERROR( @@ -1565,14 +1565,64 @@ class SliceProcessor : public AgnosticNodeProcessor { } return Status::OK(); } + + Status CustomizedProcessing() override { return ProcessInputs(); } + int start_; int end_; }; -class StridedSliceGradProcessor : public SliceProcessor { +class StridedSliceProcessor : public SliceProcessor { public: - explicit StridedSliceGradProcessor(const OptimizeContext& opt_cxt) + explicit StridedSliceProcessor(const OptimizeContext& opt_cxt) : SliceProcessor(opt_cxt) { + start_ = 1; + end_ = 3; + } + + protected: + bool ShouldProcess() const override { + return AgnosticNodeProcessor::ShouldProcess() && IsOnlyBeginEndMask(); + } + + Status CustomizedProcessing() override { + TF_RETURN_IF_ERROR(UpdateMask("begin_mask")); + TF_RETURN_IF_ERROR(UpdateMask("end_mask")); + TF_RETURN_IF_ERROR(ProcessInputs()); + return Status::OK(); + } + + private: + bool IsMaskZero(const string& mask) const { + return node_->attr().at(mask).i() == 0; + } + + bool IsOnlyBeginEndMask() const { + return IsMaskZero("ellipsis_mask") && 
IsMaskZero("new_axis_mask") && + IsMaskZero("shrink_axis_mask"); + } + + Status UpdateMask(const string& mask) { + int i = node_->attr().at(mask).i(); + if (i < 0 || i > 15) { + return errors::InvalidArgument("invalid mask value: ", i); + } + if (i == 0 || i == 1 || i == 14 || i == 15) return Status::OK(); + if (i == 2 || i == 3) i += 2; + if (i == 4 || i == 5) i += 4; + if (i == 6 || i == 7) i += 6; + if (i == 8 || i == 9) i -= 6; + if (i == 10 || i == 11) i -= 4; + if (i == 12 || i == 13) i -= 2; + node_->mutable_attr()->at(mask).set_i(i); + return Status::OK(); + } +}; + +class StridedSliceGradProcessor : public StridedSliceProcessor { + public: + explicit StridedSliceGradProcessor(const OptimizeContext& opt_cxt) + : StridedSliceProcessor(opt_cxt) { start_ = 0; end_ = 3; } @@ -1794,8 +1844,10 @@ class DataLayoutOptimizer : GraphProcessor { node_processor.reset(new PadProcessor(opt_cxt)); } else if (IsReverseV2(*node)) { node_processor.reset(new ReverseProcessor(opt_cxt)); - } else if (IsSlice(*node) || IsStridedSlice(*node)) { + } else if (IsSlice(*node)) { node_processor.reset(new SliceProcessor(opt_cxt)); + } else if (IsStridedSlice(*node)) { + node_processor.reset(new StridedSliceProcessor(opt_cxt)); } else if (IsShape(*node) || IsShapeN(*node)) { node_processor.reset(new ShapeProcessor(opt_cxt)); } else if (IsSplit(*node)) { diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 1de0c7df9f..f2985d8089 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -750,6 +750,43 @@ class LayoutOptimizerTest(test.TestCase): self.assertIn('LayoutOptimizer-StridedSlice-StridedSlice/strides', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) + def testStridedSliceWithMask(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + # This will generate a StridedSlice op with begin mask and end mask. + s = conv[:, :, 1:-1, :] + output = array_ops.identity(s) + + with session.Session() as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if node.name.startswith('LayoutOptimizerTranspose'): + num_transposes += 1 + nodes.append(node.name) + + # Four transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-0', nodes) + self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-strided_slice-0-0', + nodes) + self.assertIn('LayoutOptimizer-strided_slice-strided_slice/stack', nodes) + self.assertIn('LayoutOptimizer-strided_slice-strided_slice/stack_1', + nodes) + self.assertIn('LayoutOptimizer-strided_slice-strided_slice/stack_2', + nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + def testStridedSliceGradWithNonConstAxis(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) |