about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/tf2xla/xla_op_kernel.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/xla_op_kernel.cc')
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc  40
1 file changed, 30 insertions, 10 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 2a9eaeee14..dd3498ef7a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -455,23 +455,43 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
return Status::OK();
}
+Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape,
+ Tensor** output) {
+ // The step's default allocator is the dummy XlaCompilationAllocator which
+ // simply allocates a metadata buffer to hold the expression to which it
+ // corresponds.
+ if (expected_output_dtype(index) == DT_VARIANT) {
+ // tensor_data() is not supported for variant Tensor (i.e.,
+ // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
+ // XlaExpression inside the Tensor's tensor_data() does not work for
+ // variant. Instead construct a uint8 tensor and store the expression in its
+ // value.
+ // TODO(jpienaar): This should be refactored to stop masquerading
+ // XlaExpressions as Tensors.
+ *output = new Tensor();
+ TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(
+ context_->allocate_temp(DT_UINT8, tensor_shape, *output));
+ context_->set_output(index, **output);
+ } else {
+ TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape));
+ TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output));
+ }
+ return Status::OK();
+}
+
void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
// Makes the host Tensor that will refer to the expression.
Tensor* output = nullptr;
- auto shape = builder()->GetShape(handle);
- if (!shape.ok()) {
- SetStatus(shape.status());
+ auto shape_or = builder()->GetShape(handle);
+ if (!shape_or.ok()) {
+ SetStatus(shape_or.status());
return;
}
- // The step's default allocator is the dummy XlaCompilationAllocator which
- // simply allocates a metadata buffer to hold the expression to which it
- // corresponds.
- TensorShape tensor_shape;
- OP_REQUIRES_OK(context_,
- XLAShapeToTensorShape(shape.ValueOrDie(), &tensor_shape));
OP_REQUIRES_OK(context_,
- context_->allocate_output(index, tensor_shape, &output));
+ allocate_output(index, shape_or.ValueOrDie(), &output));
// The expression is stored in the tensor's data buffer. Fill in the
// fields now.