diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference.cc | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index d624f548b1..fdc7f41759 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -463,6 +463,17 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::MakeShape(element_type, new_dimensions); } +/* static */ StatusOr<Shape> ShapeInference::InferTokenShape( + tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) { + for (const Shape* arg_shape : arg_shapes) { + if (arg_shape->element_type() != TOKEN) { + return InvalidArgument( + "Operands of token instructions must be TOKEN types."); + } + } + return ShapeUtil::MakeTokenShape(); +} + /* static */ StatusOr<Shape> ShapeInference::InferConvertShape( const Shape& operand_shape, PrimitiveType new_element_type) { auto old_element_type = operand_shape.element_type(); |