aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla/kernels/retval_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/retval_op.cc')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/retval_op.cc19
1 files changed, 15 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index db7ea775e2..5be70a4ded 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -62,10 +63,20 @@ class RetvalOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
} else {
TensorShape shape = ctx->InputShape(0);
- TensorShape representation_shape =
- tc.is_entry_computation()
- ? tc.RepresentationShape(shape, ctx->input_type(0))
- : shape;
+ ctx->SetStatus(is_constant.status());
+ TensorShape representation_shape;
+ if (tc.is_entry_computation()) {
+ xla::StatusOr<TensorShape> shape_or_status =
+ tc.RepresentationShape(shape, ctx->input_type(0));
+ if (!shape_or_status.ok()) {
+ ctx->SetStatus(shape_or_status.status());
+ return;
+ } else {
+ representation_shape = shape_or_status.ValueOrDie();
+ }
+ } else {
+ representation_shape = shape;
+ }
xla::XlaOp output = input;
if (tc.is_entry_computation()) {