diff options
Diffstat (limited to 'tensorflow/compiler/xla/shape_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/shape_util.cc | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 7f0201942b..9267de3cfc 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -461,8 +461,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0; } -/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) { - return shape.element_type() == F32 && Rank(shape) == 0; +/* static */ bool ShapeUtil::IsScalarWithElementType( + const Shape& shape, PrimitiveType element_type) { + return IsScalar(shape) && shape.element_type() == element_type; } namespace { |