Diffstat (limited to 'tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc')
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 211
1 file changed, 211 insertions, 0 deletions
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 77f3c64c65..d091b26b65 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -288,6 +288,12 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.unary_ops_composition = true;
}
+
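+  // Disables every rewrite stage, then re-enables only the
+  // RemoveStackStridedSliceSameAxis stage so the test below exercises it in
+  // isolation.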
+ void EnableOnlyRemoveStackStridedSliceSameAxis(
+ ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.remove_stack_strided_slice_same_axis = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -3364,5 +3370,210 @@ TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
}
+TEST_F(ArithmeticOptimizerTest, RemoveStackStridedSliceSameAxis) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto a_in =
+ ops::Const(s.WithOpName("a_in"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
+ auto b_in =
+ ops::Const(s.WithOpName("b_in"), {-1.0f, -2.0f, -3.0f, -4.0f}, {2, 2});
+ auto c_in =
+ ops::Const(s.WithOpName("c_in"), {5.0f, 6.0f, 7.0f, 8.0f}, {2, 2});
+ auto a = ops::PlaceholderWithDefault(s.WithOpName("a"), a_in,
+ PartialTensorShape({-1, -1}));
+ auto b = ops::PlaceholderWithDefault(s.WithOpName("b"), b_in,
+ PartialTensorShape({-1, -1}));
+ auto c = ops::PlaceholderWithDefault(s.WithOpName("c"), c_in,
+ PartialTensorShape({-1, -1}));
+ // stacked = tf.stack((a, b, c), axis=1).
+ // stacked.shape == [2, 3, 2] (a, b, c are stacked along new axis 1)
+ auto stacked =
+ ops::Stack(s.WithOpName("stacked"), {a.output, b.output, c.output},
+ ops::Stack::Axis(1));
+ auto expanded_a = ops::ExpandDims(s.WithOpName("expanded_a"), a, {1});
+ auto expanded_b = ops::ExpandDims(s.WithOpName("expanded_b"), b, {1});
+ auto expanded_c = ops::ExpandDims(s.WithOpName("expanded_c"), c, {1});
+ auto begin_a = ops::Const(s.WithOpName("begin_a"), {0, 0, 0}, {3});
+ auto end_a = ops::Const(s.WithOpName("end_a"), {0, 1, 0}, {3});
+ auto begin_b = ops::Const(s.WithOpName("begin_b"), {0, 1, 0}, {3});
+ auto end_b = ops::Const(s.WithOpName("end_b"), {0, 2, 0}, {3});
+ auto begin_c = ops::Const(s.WithOpName("begin_c"), {0, 2, 0}, {3});
+ auto end_c = ops::Const(s.WithOpName("end_c"), {0, 3, 0}, {3});
+  auto end_c_2to = ops::Const(s.WithOpName("end_c_2to"), {0, 0, 0}, {3});
+ auto strides = ops::Const(s.WithOpName("strides"), {1, 1, 1}, {3});
+
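+  // StridedSlice mask semantics used below (one bit per input dimension):
+  //   begin_mask / end_mask: ignore begin[i] / end[i] for dimension i and take
+  //     the full range in that dimension instead.
+  //   shrink_axis_mask: treat begin[i] as an index and drop dimension i.
+  // 0b0101 marks dims 0 and 2; 0b0010 marks dim 1 (the stack axis).
+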
+ // stacked[:, 0]
+ using SS = ops::StridedSlice;
+ auto pa_slice = ops::Identity(
+ s.WithOpName("pa_slice_out"),
+ SS(s.WithOpName("pa_slice"), stacked, begin_a, end_a, strides,
+ SS::BeginMask(0b0101) // 5
+ .EllipsisMask(0)
+ .EndMask(0b0101) // 5
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0b0010))); // 2
+
+ // stacked[:, 1]
+ auto pb_slice = ops::Identity(
+ s.WithOpName("pb_slice_out"),
+ SS(s.WithOpName("pb_slice"), stacked, begin_b, end_b, strides,
+ SS::BeginMask(0b0101) // 5
+ .EllipsisMask(0)
+ .EndMask(0b0101) // 5
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0b0010))); // 2
+
+ // stacked[:, 2]
+ auto pc_slice = ops::Identity(
+ s.WithOpName("pc_slice_out"),
+ SS(s.WithOpName("pc_slice"), stacked, begin_c, end_c, strides,
+ SS::BeginMask(0b0101) // 5
+ .EllipsisMask(0)
+ .EndMask(0b0101) // 5
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0b0010))); // 2
+
+ // stacked[:, 0:1, :]
+ auto pa_slice_01 = ops::Identity(
+ s.WithOpName("pa_slice_01_out"),
+ SS(s.WithOpName("pa_slice_01"), stacked, begin_a, end_a, strides,
+ SS::BeginMask(0b0101) // 5
+ .EllipsisMask(0)
+ .EndMask(0b0101) // 5
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0)));
+
+ // stacked[:, :1, :]
+ auto pa_slice_to1 = ops::Identity(
+ s.WithOpName("pa_slice_to1_out"),
+ SS(s.WithOpName("pa_slice_to1"), stacked, begin_a, end_a, strides,
+ SS::BeginMask(0b0111) // 7
+ .EllipsisMask(0)
+ .EndMask(0b0101) // 5
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0)));
+
+ // stacked[:, 1:2, :]
+ auto pb_slice_12 = ops::Identity(
+ s.WithOpName("pb_slice_12_out"),
+ SS(s.WithOpName("pb_slice_12"), stacked, begin_b, end_b, strides,
+ SS::BeginMask(0b0101) // 5
+ .EllipsisMask(0)
+ .EndMask(0b0101) // 5
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0)));
+
+ // stacked[:, 2:, :].
+ auto pc_slice_2to = ops::Identity(
+ s.WithOpName("pc_slice_2to_out"),
+      SS(s.WithOpName("pc_slice_2to"), stacked, begin_c, end_c_2to, strides,
+ SS::BeginMask(0b0101) // 5
+ .EllipsisMask(0)
+ .EndMask(0b0111) // 7
+ .NewAxisMask(0)
+ .ShrinkAxisMask(0)));
+
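+  // After the RemoveStackStridedSliceSameAxis rewrite, the shrink-axis slices
+  // should read directly from a, b, and c, while the remaining slice outputs
+  // should be rewired to nodes created by the optimizer (verified by the name
+  // prefix check after optimization below).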
+ GrapplerItem item;
+ item.fetch = {"a",
+ "b",
+ "c",
+ "pa_slice_out",
+ "pb_slice_out",
+ "pc_slice_out",
+ "expanded_a",
+ "expanded_b",
+ "expanded_c",
+ "pa_slice_01_out",
+ "pa_slice_to1_out",
+ "pb_slice_12_out",
+ "pc_slice_2to_out"};
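+  // Indices into item.fetch (and the evaluated tensors), in the same order as
+  // the fetch list above.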
+ enum FetchItem {
+ fA,
+ fB,
+ fC,
+ fASliceOut,
+ fBSliceOut,
+ fCSliceOut,
+ fExpandedA,
+ fExpandedB,
+ fExpandedC,
+ fASlice01Out,
+ fASliceTo1Out,
+ fBSlice12Out,
+ fCSlice2ToOut,
+ };
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ // stacked[:, 0, :] == a.
+ test::ExpectTensorEqual<float>(tensors_expected[fA],
+ tensors_expected[fASliceOut]);
+ // stacked[:, 1, :] == b.
+ test::ExpectTensorEqual<float>(tensors_expected[fB],
+ tensors_expected[fBSliceOut]);
+ // stacked[:, 2, :] == c.
+ test::ExpectTensorEqual<float>(tensors_expected[fC],
+ tensors_expected[fCSliceOut]);
+
+ // stacked[:, 0:1, :] == expand_dims(a, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
+ tensors_expected[fASlice01Out]);
+
+ // stacked[:, :1, :] == expand_dims(a, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
+ tensors_expected[fASliceTo1Out]);
+
+ // stacked[:, 1:2, :] == expand_dims(b, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedB],
+ tensors_expected[fBSlice12Out]);
+ // stacked[:, 2:, :] == expand_dims(c, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedC],
+ tensors_expected[fCSlice2ToOut]);
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyRemoveStackStridedSliceSameAxis(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ for (const auto& node : output.node()) {
+ if (node.name() == "pa_slice_out") {
+ EXPECT_EQ(node.input(0), "a");
+ } else if (node.name() == "pb_slice_out") {
+ EXPECT_EQ(node.input(0), "b");
+ } else if (node.name() == "pc_slice_out") {
+ EXPECT_EQ(node.input(0), "c");
+ } else if (str_util::EndsWith(node.name(), "_out")) {
+ EXPECT_EQ(strings::StrCat(node.input(0), "_out"),
+ strings::StrCat(
+ "ArithmeticOptimizer/RemoveStackStridedSliceSameAxis_",
+ node.name()));
+ }
+ }
+
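+  // Re-evaluate the optimized graph; the slice outputs must still match the
+  // pre-optimization results.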
+ auto tensors = EvaluateNodes(output, item.fetch);
+
+ // stacked[:, 0, :] == a.
+ test::ExpectTensorEqual<float>(tensors_expected[fA], tensors[fASliceOut]);
+
+ // stacked[:, 1, :] == b.
+ test::ExpectTensorEqual<float>(tensors_expected[fB], tensors[fBSliceOut]);
+ // stacked[:, 2, :] == c.
+ test::ExpectTensorEqual<float>(tensors_expected[fC], tensors[fCSliceOut]);
+
+ // stacked[:, 0:1, :] == expand_dims(a, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
+ tensors[fASlice01Out]);
+
+ // stacked[:, :1, :] == expand_dims(a, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedA],
+ tensors[fASliceTo1Out]);
+
+ // stacked[:, 1:2, :] == expand_dims(b, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedB],
+ tensors[fBSlice12Out]);
+ // stacked[:, 2:, :] == expand_dims(c, 1).
+ test::ExpectTensorEqual<float>(tensors_expected[fExpandedC],
+ tensors[fCSlice2ToOut]);
+}
+
} // namespace grappler
} // namespace tensorflow