aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-03-27 15:37:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-27 15:40:36 -0700
commit05ddf373980fae94a2c73cf93161332484e102fd (patch)
treeed60ae638764dc82047656dc49cf455b2b88ff3e /tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
parent9a0b91023d8444cd4691be10b36ce469ca08058d (diff)
[XLA] Fold reduce-window(convert(pad(X))) into reduce-window(convert(X))
ReduceWindow operations are done in higher precision to avoid accumulation error. Convert operations can find their way between a ReduceWindow and a Pad which can prevent a Pad from combining with a ReduceWindow. Fix this by looking past the Convert while also checking that the Convert'd Pad's init value is identical to the reduce-window value. PiperOrigin-RevId: 190686175
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc85
1 files changed, 85 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 3b80a827bf..20c549562d 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2338,6 +2338,91 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
}
+// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
+// ReduceWindow(Convert(op), x).
+TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
+ HloModule module(TestName());
+ HloComputation::Builder builder(TestName());
+
+ // Create operand to the pad.
+ HloInstruction* parameter =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(BF16, {1, 2, 3, 4}), "p0"));
+
+ // Create the pad.
+ PaddingConfig padding = MakeNoPaddingConfig(4);
+ padding.mutable_dimensions(1)->set_edge_padding_low(1);
+ padding.mutable_dimensions(3)->set_edge_padding_high(2);
+
+ HloInstruction* pad_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
+ ShapeUtil::MakeShape(BF16, {1, 3, 3, 5}), parameter, pad_value, padding));
+
+ HloInstruction* convert =
+ builder.AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::ChangeElementType(pad->shape(), F32), pad));
+
+ // Create add computation.
+ HloComputation* add_computation = nullptr;
+ {
+ HloComputation::Builder builder(TestName() + ".add");
+ const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ HloInstruction* p0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "p0"));
+ HloInstruction* p1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "p1"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
+ add_computation = module.AddEmbeddedComputation(builder.Build());
+ }
+
+ // Create the reduce-window.
+ Window window;
+ for (int64 i = 0; i < ShapeUtil::Rank(pad->shape()); ++i) {
+ auto* dim = window.add_dimensions();
+ dim->set_size(1);
+ dim->set_padding_low(10);
+ dim->set_padding_high(100);
+ dim->set_window_dilation(1);
+ dim->set_base_dilation(1);
+ }
+ const Shape reduce_window_shape =
+ ShapeUtil::MakeShape(F32, {111, 113, 113, 115});
+ HloInstruction* reduce_init_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(5.0f)));
+ HloInstruction* reduce_window =
+ builder.AddInstruction(HloInstruction::CreateReduceWindow(
+ reduce_window_shape, convert, reduce_init_value, window,
+ add_computation));
+
+ // Build the computation and run the simplifier.
+ auto computation = module.AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root, reduce_window);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+
+ // Running simplification again should not result in any further changes.
+ ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+
+ // Verify the result
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::ReduceWindow(op::Convert(parameter), op::Constant()));
+ EXPECT_TRUE(ShapeUtil::Equal(root->shape(), reduce_window_shape))
+ << ShapeUtil::HumanString(root->shape()) << " vs "
+ << ShapeUtil::HumanString(reduce_window_shape);
+ EXPECT_EQ(root->window().dimensions(0).padding_low(), 10);
+ EXPECT_EQ(root->window().dimensions(1).padding_low(), 11);
+ EXPECT_EQ(root->window().dimensions(2).padding_low(), 10);
+ EXPECT_EQ(root->window().dimensions(3).padding_low(), 10);
+ EXPECT_EQ(root->window().dimensions(0).padding_high(), 100);
+ EXPECT_EQ(root->window().dimensions(1).padding_high(), 100);
+ EXPECT_EQ(root->window().dimensions(2).padding_high(), 100);
+ EXPECT_EQ(root->window().dimensions(3).padding_high(), 102);
+}
+
TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
HloComputation::Builder builder(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {448, 2048, 1, 1});