diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 211 |
1 files changed, 211 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 77f3c64c65..d091b26b65 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -288,6 +288,12 @@ class ArithmeticOptimizerTest : public GrapplerTest { DisableAllStages(optimizer); optimizer->options_.unary_ops_composition = true; } + + void EnableOnlyRemoveStackStridedSliceSameAxis( + ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.remove_stack_strided_slice_same_axis = true; + } }; TEST_F(ArithmeticOptimizerTest, NoOp) { @@ -3364,5 +3370,210 @@ TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) { test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); } +TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto a_in = + ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2}); + auto b_in = + ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2}); + auto c_in = + ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2}); + auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in, + PartialTensorShape({-1, -1})); + auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in, + PartialTensorShape({-1, -1})); + auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in, + PartialTensorShape({-1, -1})); + // stacked = tf.stack((a, b, c), axis=1). + // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1) + auto stacked = + ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output}, + ops::Stack::Axis(1)); + auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1}); + auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1}); + auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1}); + auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3}); + auto end_a = ops::Const(s.WithOpName("end_a"), {0, 1, 0}, {3}); + auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3}); + auto end_b = ops::Const(s.WithOpName("end_b"), {0, 2, 0}, {3}); + auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3}); + auto end_c = ops::Const(s.WithOpName("end_c"), {0, 3, 0}, {3}); + auto end_c_1to = ops::Const(s.WithOpName("begin_c_2to"), {0, 0, 0}, {3}); + auto strides = ops::Const(s.WithOpName("strides"), {1, 1, 1}, {3}); + + // stacked[:, 0] + using SS = ops::StridedSlice; + auto pa_slice = ops::Identity( + s.WithOpName("pa_slice_out"), + SS(s.WithOpName("pa_slice"), stacked, begin_a, end_a, strides, + SS::BeginMask(0b0101) // 5 + .EllipsisMask(0) + .EndMask(0b0101) // 5 + .NewAxisMask(0) + .ShrinkAxisMask(0b0010))); // 2 + + // stacked[:, 1] + auto pb_slice = ops::Identity( + s.WithOpName("pb_slice_out"), + SS(s.WithOpName("pb_slice"), stacked, begin_b, end_b, strides, + SS::BeginMask(0b0101) // 5 + .EllipsisMask(0) + .EndMask(0b0101) // 5 + .NewAxisMask(0) + .ShrinkAxisMask(0b0010))); // 2 + + // stacked[:, 2] + auto pc_slice = ops::Identity( + s.WithOpName("pc_slice_out"), + SS(s.WithOpName("pc_slice"), stacked, begin_c, end_c, strides, + SS::BeginMask(0b0101) // 5 + .EllipsisMask(0) + .EndMask(0b0101) // 5 + .NewAxisMask(0) + .ShrinkAxisMask(0b0010))); // 2 + + // stacked[:, 0:1, :] + auto pa_slice_01 = ops::Identity( + s.WithOpName("pa_slice_01_out"), + SS(s.WithOpName("pa_slice_01"), stacked, begin_a, end_a, strides, + SS::BeginMask(0b0101) // 5 + .EllipsisMask(0) + .EndMask(0b0101) // 5 + .NewAxisMask(0) + .ShrinkAxisMask(0))); + + // stacked[:, :1, :] + auto pa_slice_to1 = ops::Identity( + s.WithOpName("pa_slice_to1_out"), + SS(s.WithOpName("pa_slice_to1"), stacked, begin_a, end_a, strides, + SS::BeginMask(0b0111) // 7 + .EllipsisMask(0) + .EndMask(0b0101) // 5 + .NewAxisMask(0) + .ShrinkAxisMask(0))); + + // stacked[:, 1:2, :] + auto pb_slice_12 = ops::Identity( + s.WithOpName("pb_slice_12_out"), + SS(s.WithOpName("pb_slice_12"), stacked, begin_b, end_b, strides, + SS::BeginMask(0b0101) // 5 + .EllipsisMask(0) + .EndMask(0b0101) // 5 + .NewAxisMask(0) + .ShrinkAxisMask(0))); + + // stacked[:, 2:, :]. + auto pc_slice_2to = ops::Identity( + s.WithOpName("pc_slice_2to_out"), + SS(s.WithOpName("pc_slice_2to"), stacked, begin_c, end_c_1to, strides, + SS::BeginMask(0b0101) // 5 + .EllipsisMask(0) + .EndMask(0b0111) // 7 + .NewAxisMask(0) + .ShrinkAxisMask(0))); + + GrapplerItem item; + item.fetch = {"a", + "b", + "c", + "pa_slice_out", + "pb_slice_out", + "pc_slice_out", + "expanded_a", + "expanded_b", + "expanded_c", + "pa_slice_01_out", + "pa_slice_to1_out", + "pb_slice_12_out", + "pc_slice_2to_out"}; + enum FetchItem { + fA, + fB, + fC, + fASliceOut, + fBSliceOut, + fCSliceOut, + fExpandedA, + fExpandedB, + fExpandedC, + fASlice01Out, + fASliceTo1Out, + fBSlice12Out, + fCSlice2ToOut, + }; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + + // stacked[:, 0, :] == a. + test::ExpectTensorEqual<float>(tensors_expected[fA], + tensors_expected[fASliceOut]); + // stacked[:, 1, :] == b. + test::ExpectTensorEqual<float>(tensors_expected[fB], + tensors_expected[fBSliceOut]); + // stacked[:, 2, :] == c. + test::ExpectTensorEqual<float>(tensors_expected[fC], + tensors_expected[fCSliceOut]); + + // stacked[:, 0:1, :] == expand_dims(a, 1). + test::ExpectTensorEqual<float>(tensors_expected[fExpandedA], + tensors_expected[fASlice01Out]); + + // stacked[:, :1, :] == expand_dims(a, 1). + test::ExpectTensorEqual<float>(tensors_expected[fExpandedA], + tensors_expected[fASliceTo1Out]); + + // stacked[:, 1:2, :] == expand_dims(b, 1). + test::ExpectTensorEqual<float>(tensors_expected[fExpandedB], + tensors_expected[fBSlice12Out]); + // stacked[:, 2:, :] == expand_dims(c, 1). + test::ExpectTensorEqual<float>(tensors_expected[fExpandedC], + tensors_expected[fCSlice2ToOut]); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyRemoveStackStridedSliceSameAxis(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + + for (const auto& node : output.node()) { + if (node.name() == "pa_slice_out") { + EXPECT_EQ(node.input(0), "a"); + } else if (node.name() == "pb_slice_out") { + EXPECT_EQ(node.input(0), "b"); + } else if (node.name() == "pc_slice_out") { + EXPECT_EQ(node.input(0), "c"); + } else if (str_util::EndsWith(node.name(), "_out")) { + EXPECT_EQ(strings::StrCat(node.input(0), "_out"), + strings::StrCat( + "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_", + node.name())); + } + } + + auto tensors = EvaluateNodes(output, item.fetch); + + // stacked[:, 0, :] == a. + test::ExpectTensorEqual<float>(tensors_expected[fA], tensors[fASliceOut]); + + // stacked[:, 1, :] == b. + test::ExpectTensorEqual<float>(tensors_expected[fB], tensors[fBSliceOut]); + // stacked[:, 2, :] == c. + test::ExpectTensorEqual<float>(tensors_expected[fC], tensors[fCSliceOut]); + + // stacked[:, 0:1, :] == expand_dims(a, 1). + test::ExpectTensorEqual<float>(tensors_expected[fExpandedA], + tensors[fASlice01Out]); + + // stacked[:, :1, :] == expand_dims(a, 1). + test::ExpectTensorEqual<float>(tensors_expected[fExpandedA], + tensors[fASliceTo1Out]); + + // stacked[:, 1:2, :] == expand_dims(b, 1). + test::ExpectTensorEqual<float>(tensors_expected[fExpandedB], + tensors[fBSlice12Out]); + // stacked[:, 2:, :] == expand_dims(c, 1). + test::ExpectTensorEqual<float>(tensors_expected[fExpandedC], + tensors[fCSlice2ToOut]); +} + } // namespace grappler } // namespace tensorflow |