aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-09-20 14:29:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 14:34:28 -0700
commitbf5324fd55a894ac00d10b7cfb2d26f3d9f7f5c9 (patch)
tree2be8a9982012897ec3ac755627e3149af21c6b4f
parent61ba909b0a82fa3745964f0eb2a7949b2249982e (diff)
[XLA] Don't create mixed precision operations accidentally
The reshape we created change the element type unintentionally. PiperOrigin-RevId: 213883142
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc3
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc17
2 files changed, 11 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 4ef1dffa73..75dae7a714 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -754,11 +754,12 @@ StatusOr<bool> AlgebraicSimplifierVisitor::HandleDotStrengthReduction(
};
auto reshape_if_necessary = [&](HloInstruction* hlo) {
+ hlo = as_type(hlo, dot->shape().element_type());
if (!ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) {
hlo = computation_->AddInstruction(
HloInstruction::CreateReshape(dot->shape(), hlo));
}
- return as_type(hlo, dot->shape().element_type());
+ return hlo;
};
auto add_reduce_in_f32 = [&](HloInstruction* hlo, const int64 dim) {
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 3fc1ba2427..2047f894b4 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -3233,17 +3233,18 @@ INSTANTIATE_TEST_CASE_P(
class DotStrengthReductionTest
: public AlgebraicSimplifierTest,
public ::testing::WithParamInterface<
- ::testing::tuple<int, int, int, bool, bool>> {};
+ ::testing::tuple<int, int, int, bool, bool, PrimitiveType>> {};
TEST_P(DotStrengthReductionTest, DotStrengthReduction) {
int m, k, n;
bool transpose_lhs, transpose_rhs;
- std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam();
+ PrimitiveType element_type;
+ std::tie(m, k, n, transpose_lhs, transpose_rhs, element_type) = GetParam();
- Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n});
- Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
- Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m});
- Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n});
- Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k});
+ Shape dot_shape = ShapeUtil::MakeShape(element_type, {m, n});
+ Shape lhs_shape = ShapeUtil::MakeShape(element_type, {m, k});
+ Shape transposed_lhs_shape = ShapeUtil::MakeShape(element_type, {k, m});
+ Shape rhs_shape = ShapeUtil::MakeShape(element_type, {k, n});
+ Shape transposed_rhs_shape = ShapeUtil::MakeShape(element_type, {n, k});
HloComputation::Builder builder(TestName());
auto lhs = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -3285,7 +3286,7 @@ INSTANTIATE_TEST_CASE_P(
DotStrengthReductionTestInstantiation, DotStrengthReductionTest,
::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2),
::testing::Values(1, 2), ::testing::Bool(),
- ::testing::Bool()));
+ ::testing::Bool(), ::testing::Values(F32, BF16)));
struct DotOfConcatTestSpec {
int64 m;