diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_op_kernel.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.cc | 40 |
1 files changed, 30 insertions, 10 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 2a9eaeee14..dd3498ef7a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -455,23 +455,43 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } +Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape, + Tensor** output) { + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + if (expected_output_dtype(index) == DT_VARIANT) { + // tensor_data() is not supported for variant Tensor (i.e., + // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the + // XlaExpression inside the Tensor's tensor_data() does not work for + // variant. Instead construct a uint8 tensor and store the expression in its + // value. + // TODO(jpienaar): This should be refactored to stop masquerading + // XlaExpressions as Tensors. + *output = new Tensor(); + TensorShape tensor_shape; + TF_RETURN_IF_ERROR( + context_->allocate_temp(DT_UINT8, tensor_shape, *output)); + context_->set_output(index, **output); + } else { + TensorShape tensor_shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape)); + TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output)); + } + return Status::OK(); +} + void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { // Makes the host Tensor that will refer to the expression. Tensor* output = nullptr; - auto shape = builder()->GetShape(handle); - if (!shape.ok()) { - SetStatus(shape.status()); + auto shape_or = builder()->GetShape(handle); + if (!shape_or.ok()) { + SetStatus(shape_or.status()); return; } - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - TensorShape tensor_shape; - OP_REQUIRES_OK(context_, - XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape)); OP_REQUIRES_OK(context_, - context_->allocate_output(index, tensor_shape, &output)); + allocate_output(index, shape_or.ValueOrDie(), &output)); // The expression is stored in the tensor's data buffer. Fill in the // fields now. |