diff options
author | David Majnemer <majnemer@google.com> | 2018-09-20 14:29:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 14:34:28 -0700 |
commit | bf5324fd55a894ac00d10b7cfb2d26f3d9f7f5c9 (patch) | |
tree | 2be8a9982012897ec3ac755627e3149af21c6b4f | |
parent | 61ba909b0a82fa3745964f0eb2a7949b2249982e (diff) |
[XLA] Don't create mixed precision operations accidentally
The reshape we created changed the element type unintentionally.
PiperOrigin-RevId: 213883142
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 17 |
2 files changed, 11 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 4ef1dffa73..75dae7a714 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -754,11 +754,12 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction( }; auto reshape_if_necessary = [&](HloInstruction* hlo) { + hlo = as_type(hlo, dot->shape().element_type()); if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { hlo = computation_->AddInstruction( HloInstruction::CreateReshape(dot->shape(), hlo)); } - return as_type(hlo, dot->shape().element_type()); + return hlo; }; auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) { diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 3fc1ba2427..2047f894b4 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -3233,17 +3233,18 @@ INSTANTIATE_TEST_CASE_P( class DotStrengthReductionTest : public AlgebraicSimplifierTest, public ::testing::WithParamInterface< - ::testing::tuple<int, int, int, bool, bool>> {}; + ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {}; TEST_P(DotStrengthReductionTest, DotStrengthReduction) { int m, k, n; bool transpose_lhs, transpose_rhs; - std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam(); + PrimitiveType element_type; + std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam(); - Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); - Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); - Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m}); - Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); - Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k}); + Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k}); + Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m}); + Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n}); + Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k}); HloComputation::Builder builder(TestName()); auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( @@ -3285,7 +3286,7 @@ INSTANTIATE_TEST_CASE_P( DotStrengthReductionTestInstantiation, DotStrengthReductionTest, ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Values(1, 2), ::testing::Bool(), - ::testing::Bool())); + ::testing::Bool(), ::testing::Values(F32, BF16))); struct DotOfConcatTestSpec { int64 m; |