aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
diff options
context:
space:
mode:
authorGravatar Yuanzhong Xu <yuanzx@google.com>2018-07-09 16:07:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-09 16:11:41 -0700
commit3da378059e7595700a89e4d4cf29a48c5748ea18 (patch)
tree462e759f8affcaec3ee83fe317d6dd03ba930c9e /tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
parentbc2674d09efbd87ae81ae41b81f1d152f37fac2a (diff)
[XLA] BFloat16 propagation: add de-aliasing copies before while loop inputs.
The while loop input and output alias each other, so as long as an input is also used by other ops that could not use BF16, the propagation pass could not change such an input/output to BF16 even if all uses in the while loop could use BF16. Add copies for each while loop operand. This increases the chance to propagate BF16 through the while loop; if some of these copies do not help, they will remain same-shape copies and be removed at the end. This can sometimes increase HBM usage because both BF16 and F32 copies are alive, and can sometimes reduce HBM usage. PiperOrigin-RevId: 203848348
Diffstat (limited to 'tensorflow/compiler/xla/service/bfloat16_propagation_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation_test.cc12
1 files changed, 4 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index 23aa83ea88..aeafb25ad7 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -240,12 +240,10 @@ TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
EXPECT_TRUE(PropagatePrecision(module.get()));
EXPECT_EQ(computation->root_instruction(), dot);
- EXPECT_TRUE(OutputsBF16(add0));
EXPECT_TRUE(OutputsBF16(add1));
EXPECT_TRUE(OutputsBF16(lhs));
- // rhs is a get-tuple-element, which does not define a buffer, but its shape
- // should also be adjusted accordingly.
- EXPECT_TRUE(OutputsBF16(rhs));
+
+ // add0 and rhs have been eliminated by simplification and DCE.
}
// Tests that a non-fusion computation's root should not be changed.
@@ -734,10 +732,8 @@ TEST_F(BFloat16PropagationTest, NoopConversionRemoved) {
EXPECT_TRUE(PropagatePrecision(module.get()));
EXPECT_EQ(computation->root_instruction(), add2);
- EXPECT_EQ(add2->operand(0), gte0);
- EXPECT_EQ(add2->operand(1), gte1);
- EXPECT_EQ(gte0->shape().element_type(), BF16);
- EXPECT_EQ(gte1->shape().element_type(), BF16);
+ EXPECT_EQ(add2->operand(0), add0);
+ EXPECT_EQ(add2->operand(1), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
EXPECT_EQ(add1->shape().element_type(), BF16);
}