aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/shape_util.cc')
-rw-r--r--tensorflow/compiler/xla/shape_util.cc5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 7f0201942b..9267de3cfc 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -461,8 +461,9 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
return ShapeUtil::IsArray(shape) && ElementsIn(shape) == 0;
}
-/* static */ bool ShapeUtil::IsScalarF32(const Shape& shape) {
- return shape.element_type() == F32 && Rank(shape) == 0;
+/* static */ bool ShapeUtil::IsScalarWithElementType(
+ const Shape& shape, PrimitiveType element_type) {
+ return IsScalar(shape) && shape.element_type() == element_type;
}
namespace {