aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/shape_inference.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/shape_inference.cc')
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index d624f548b1..fdc7f41759 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -463,6 +463,17 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return ShapeUtil::MakeShape(element_type, new_dimensions);
}
+/* static */ StatusOr<Shape> ShapeInference::InferTokenShape(
+ tensorflow::gtl::ArraySlice<const Shape*> arg_shapes) {
+ for (const Shape* arg_shape : arg_shapes) {
+ if (arg_shape->element_type() != TOKEN) {
+ return InvalidArgument(
+ "Operands of token instructions must be TOKEN types.");
+ }
+ }
+ return ShapeUtil::MakeTokenShape();
+}
+
/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
const Shape& operand_shape, PrimitiveType new_element_type) {
auto old_element_type = operand_shape.element_type();