aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/shape_util.h
diff options
context:
space:
mode:
authorGravatar Chris Leary <leary@google.com>2018-10-08 15:42:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 15:46:10 -0700
commitcb057ea64032e551027c8f9058a9d28a258c9d6b (patch)
treec0c1876cb9a22d64671ca52719bc6cf2ca0e8d81 /tensorflow/compiler/xla/shape_util.h
parenteb0f862ba60f41e8d0f06ceb6fc65f7f9905a25a (diff)
[XLA] Make overly-specific ShapeUtil predicate a little more general.
PiperOrigin-RevId: 216263039
Diffstat (limited to 'tensorflow/compiler/xla/shape_util.h')
-rw-r--r--tensorflow/compiler/xla/shape_util.h5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index d8bb27beae..73f541d505 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -312,7 +312,10 @@ class ShapeUtil {
static bool IsEffectiveScalar(const Shape& shape) {
return IsArray(shape) && TrueRank(shape) == 0;
}
- static bool IsScalarF32(const Shape& shape);
+
+ // Returns whether "shape" is a scalar (array) with the given element_type.
+ static bool IsScalarWithElementType(const Shape& shape,
+ PrimitiveType element_type);
// Extracts the size of the shape's dimension at dimension number
// GetDimensionNumber(dimension_number).