aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Chris Leary <leary@google.com>2018-10-08 15:42:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 15:46:10 -0700
commitcb057ea64032e551027c8f9058a9d28a258c9d6b (patch)
treec0c1876cb9a22d64671ca52719bc6cf2ca0e8d81
parenteb0f862ba60f41e8d0f06ceb6fc65f7f9905a25a (diff)
[XLA] Make overly-specific ShapeUtil predicate a little more general.
PiperOrigin-RevId: 216263039
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_query.cc2
-rw-r--r--tensorflow/compiler/xla/shape_util.cc5
-rw-r--r--tensorflow/compiler/xla/shape_util.h5
4 files changed, 10 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index c1b7c3832b..d93351fe04 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -135,7 +135,8 @@ TEST_F(HloInstructionTest, BasicProperties) {
auto parameter = HloInstruction::CreateParameter(1, r0f32_, "foo");
EXPECT_EQ(HloOpcode::kParameter, parameter->opcode());
- EXPECT_TRUE(ShapeUtil::IsScalarF32(parameter->shape()));
+ EXPECT_TRUE(ShapeUtil::IsScalarWithElementType(parameter->shape(), F32));
+ EXPECT_FALSE(ShapeUtil::IsScalarWithElementType(parameter->shape(), S32));
EXPECT_EQ(0, parameter->operand_count());
}
diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc
index 2a07b6fcbc..2d5197be9e 100644
--- a/tensorflow/compiler/xla/service/hlo_query.cc
+++ b/tensorflow/compiler/xla/service/hlo_query.cc
@@ -24,7 +24,7 @@ namespace hlo_query {
bool IsConstantR0F32(HloInstruction* instruction, float* out) {
if (instruction->opcode() == HloOpcode::kConstant &&
- ShapeUtil::IsScalarF32(instruction->shape())) {
+ ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) {
*out = instruction->literal().Get<float>({});
return true;
}
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 7f0201942b..9267de3cfc 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -461,8 +461,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
}
-/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) {
- return shape.element_type() == F32 && Rank(shape) == 0;
+/* static */ bool ShapeUtil::IsScalarWithElementType(
+ const Shape& shape, PrimitiveType element_type) {
+ return IsScalar(shape) && shape.element_type() == element_type;
}
namespace {
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index d8bb27beae..73f541d505 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -312,7 +312,10 @@ class ShapeUtil {
static bool IsEffectiveScalar(const Shape& shape) {
return IsArray(shape) && TrueRank(shape) == 0;
}
- static bool IsScalarF32(const Shape& shape);
+
+ // Returns whether "shape" is a scalar (array) with the given element_type.
+ static bool IsScalarWithElementType(const Shape& shape,
+ PrimitiveType element_type);
// Extracts the size of the shape's dimension at dimension number
// GetDimensionNumber(dimension_number).