From cb057ea64032e551027c8f9058a9d28a258c9d6b Mon Sep 17 00:00:00 2001
From: Chris Leary
Date: Mon, 8 Oct 2018 15:42:17 -0700
Subject: [XLA] Make overly-specific ShapeUtil predicate a little more general.

PiperOrigin-RevId: 216263039
---
 tensorflow/compiler/xla/service/hlo_instruction_test.cc | 3 ++-
 tensorflow/compiler/xla/service/hlo_query.cc            | 2 +-
 tensorflow/compiler/xla/shape_util.cc                   | 5 +++--
 tensorflow/compiler/xla/shape_util.h                    | 5 ++++-
 4 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index c1b7c3832b..d93351fe04 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -135,7 +135,8 @@ TEST_F(HloInstructionTest, BasicProperties) {
   auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo");
 
   EXPECT_EQ(HloOpcode::kParameter, parameter->opcode());
-  EXPECT_TRUE(ShapeUtil::IsScalarF32(parameter->shape()));
+  EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32));
+  EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32));
   EXPECT_EQ(0, parameter->operand_count());
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc
index 2a07b6fcbc..2d5197be9e 100644
--- a/tensorflow/compiler/xla/service/hlo_query.cc
+++ b/tensorflow/compiler/xla/service/hlo_query.cc
@@ -24,7 +24,7 @@ namespace hlo_query {
 
 bool IsConstantR0F32(HloInstruction* instruction, float* out) {
   if (instruction->opcode() == HloOpcode::kConstant &&
-      ShapeUtil::IsScalarF32(instruction->shape())) {
+      ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) {
     *out = instruction->literal().Get<float>({});
     return true;
   }
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 7f0201942b..9267de3cfc 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -461,8 +461,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
   return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
 }
 
-/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) {
-  return shape.element_type() == F32 && Rank(shape) == 0;
+/* static */ bool ShapeUtil::IsScalarWithElementType(
+    const Shape& shape, PrimitiveType element_type) {
+  return IsScalar(shape) && shape.element_type() == element_type;
 }
 
 namespace {
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index d8bb27beae..73f541d505 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -312,7 +312,10 @@ class ShapeUtil {
   static bool IsEffectiveScalar(const Shape& shape) {
     return IsArray(shape) && TrueRank(shape) == 0;
   }
-  static bool IsScalarF32(const Shape& shape);
+
+  // Returns whether "shape" is a scalar (array) with the given element_type.
+  static bool IsScalarWithElementType(const Shape& shape,
+                                      PrimitiveType element_type);
 
   // Extracts the size of the shape's dimension at dimension number
   // GetDimensionNumber(dimension_number).
-- 
cgit v1.2.3