diff options
author | Chris Leary <leary@google.com> | 2018-10-08 15:42:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-08 15:46:10 -0700 |
commit | cb057ea64032e551027c8f9058a9d28a258c9d6b (patch) | |
tree | c0c1876cb9a22d64671ca52719bc6cf2ca0e8d81 /tensorflow/compiler/xla/shape_util.h | |
parent | eb0f862ba60f41e8d0f06ceb6fc65f7f9905a25a (diff) |
[XLA] Make overly-specific ShapeUtil predicate a little more general.
PiperOrigin-RevId: 216263039
Diffstat (limited to 'tensorflow/compiler/xla/shape_util.h')
-rw-r--r-- | tensorflow/compiler/xla/shape_util.h | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index d8bb27beae..73f541d505 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -312,7 +312,10 @@ class ShapeUtil { static bool IsEffectiveScalar(const Shape& shape) { return IsArray(shape) && TrueRank(shape) == 0; } - static bool IsScalarF32(const Shape& shape); + + // Returns whether "shape" is a scalar (array) with the given element_type. + static bool IsScalarWithElementType(const Shape& shape, + PrimitiveType element_type); // Extracts the size of the shape's dimension at dimension number // GetDimensionNumber(dimension_number). |