diff options
author | 2018-10-08 15:42:17 -0700 | |
---|---|---|
committer | 2018-10-08 15:46:10 -0700 | |
commit | cb057ea64032e551027c8f9058a9d28a258c9d6b (patch) | |
tree | c0c1876cb9a22d64671ca52719bc6cf2ca0e8d81 /tensorflow/compiler/xla/service | |
parent | eb0f862ba60f41e8d0f06ceb6fc65f7f9905a25a (diff) |
[XLA] Make overly-specific ShapeUtil predicate a little more general.
PiperOrigin-RevId: 216263039
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction_test.cc | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_query.cc | 2 |
2 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index c1b7c3832b..d93351fe04 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -135,7 +135,8 @@ TEST_F(HloInstructionTest, BasicProperties) { auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo"); EXPECT_EQ(HloOpcode::kParameter, parameter->opcode()); - EXPECT_TRUE(ShapeUtil::IsScalarF32(parameter->shape())); + EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32)); + EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32)); EXPECT_EQ(0, parameter->operand_count()); } diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index 2a07b6fcbc..2d5197be9e 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -24,7 +24,7 @@ namespace hlo_query { bool IsConstantR0F32(HloInstruction* instruction, float* out) { if (instruction->opcode() == HloOpcode::kConstant && - ShapeUtil::IsScalarF32(instruction->shape())) { + ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) { *out = instruction->literal().Get<float>({}); return true; } |