diff options
author | 2018-06-29 14:52:53 -0700 | |
---|---|---|
committer | 2018-06-29 14:55:24 -0700 | |
commit | 234aade2123d7927822dd421cbcd219be25f105e (patch) | |
tree | 36f5cf9a78f97d0478e066de858fdf15affbfd77 /tensorflow | |
parent | 56679a3ee01e50d86db1f26479e58e1439587e08 (diff) |
[TF:XLA] Change the return type of ShapeRepresentationFn from TensorShape to StatusOr<TensorShape>.
PiperOrigin-RevId: 202711909
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/compiler/jit/xla_device_context.cc | 20 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/retval_op.cc | 19 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.cc | 16 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.h | 4 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_context.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_context.h | 11 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.cc | 10 |
8 files changed, 58 insertions, 31 deletions
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 37005479dc..0188faaf51 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -56,9 +56,9 @@ XlaTransferManager::XlaTransferManager( transfer_as_literal_(transfer_as_literal), shape_representation_fn_(std::move(shape_representation_fn)) { if (!shape_representation_fn_) { - shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) { - return shape; - }; + shape_representation_fn_ = + [](const TensorShape& shape, + DataType dtype) -> xla::StatusOr<TensorShape> { return shape; }; } } @@ -119,8 +119,13 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); CHECK(xla_tensor); - TensorShape shape = shape_representation_fn_(device_tensor->shape(), - device_tensor->dtype()); + xla::StatusOr<TensorShape> shape_or_status = shape_representation_fn_( + device_tensor->shape(), device_tensor->dtype()); + if (!shape_or_status.ok()) { + done(shape_or_status.status()); + return; + } + TensorShape shape = shape_or_status.ValueOrDie(); if (!xla_tensor->has_shaped_buffer()) { Status s = xla_tensor->AllocateShapedBuffer( device_tensor->dtype(), shape, client_, @@ -217,8 +222,9 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor, CHECK(xla_src && xla_dst) << "Missing destination tensor for device-to-device copy"; if (!xla_dst->has_shaped_buffer()) { - TensorShape shape = - shape_representation_fn_(src_tensor.shape(), src_tensor.dtype()); + TF_ASSIGN_OR_RETURN( + TensorShape shape, + shape_representation_fn_(src_tensor.shape(), src_tensor.dtype())); TF_RETURN_IF_ERROR( xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_, stream_->parent()->device_ordinal())); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index aa9c0596d1..1a09ef643f 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -164,6 +164,7 @@ cc_library( "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index db7ea775e2..5be70a4ded 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -62,10 +63,20 @@ class RetvalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { TensorShape shape = ctx->InputShape(0); - TensorShape representation_shape = - tc.is_entry_computation() - ? tc.RepresentationShape(shape, ctx->input_type(0)) - : shape; + ctx->SetStatus(is_constant.status()); + TensorShape representation_shape; + if (tc.is_entry_computation()) { + xla::StatusOr<TensorShape> shape_or_status = + tc.RepresentationShape(shape, ctx->input_type(0)); + if (!shape_or_status.ok()) { + ctx->SetStatus(shape_or_status.status()); + return; + } else { + representation_shape = shape_or_status.ValueOrDie(); + } + } else { + representation_shape = shape; + } xla::XlaOp output = input; if (tc.is_entry_computation()) { diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 0c98c20805..99535faedc 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -231,10 +231,13 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, case XlaCompiler::Argument::kConstant: LOG(FATAL) << "Unreachable case"; case XlaCompiler::Argument::kParameter: { - TensorShape shape = - is_entry_computation - ? options_.shape_representation_fn(arg.shape, arg.type) - : arg.shape; + TensorShape shape; + if (is_entry_computation) { + TF_ASSIGN_OR_RETURN( + shape, options_.shape_representation_fn(arg.shape, arg.type)); + } else { + shape = arg.shape; + } return TensorShapeToXLAShape(arg.type, shape, xla_shape); } case XlaCompiler::Argument::kResource: { @@ -242,8 +245,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, switch (arg.resource_kind) { case XlaResource::kVariable: { - TensorShape representation_shape = - options_.shape_representation_fn(arg.shape, arg.type); + TF_ASSIGN_OR_RETURN( + TensorShape representation_shape, + options_.shape_representation_fn(arg.shape, arg.type)); return TensorShapeToXLAShape(arg.type, representation_shape, xla_shape); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 80593eaca5..079c99797e 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" @@ -242,7 +243,8 @@ class XlaCompiler { std::shared_ptr<xla::XlaComputation> computation; }; - typedef std::function<TensorShape(const TensorShape&, DataType)> + typedef std::function<xla::StatusOr<TensorShape>(const TensorShape&, + DataType)> ShapeRepresentationFn; struct Options { // Name of the compilation device to use. It must be set by the caller. diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index d0b5606907..fd39a58ce6 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -66,8 +66,8 @@ XlaContext::XlaContext( XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, bool is_entry_computation, - const std::function<TensorShape(const TensorShape&, DataType)>* - shape_representation_fn) + const std::function<xla::StatusOr<TensorShape>( + const TensorShape&, DataType)>* shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), @@ -119,8 +119,8 @@ Status XlaContext::CreateResource( return Status::OK(); } -TensorShape XlaContext::RepresentationShape(const TensorShape& shape, - DataType type) const { +xla::StatusOr<TensorShape> XlaContext::RepresentationShape( + const TensorShape& shape, DataType type) const { return (*shape_representation_fn_)(shape, type); } diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 5960daaefd..38d8cd653c 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_client/xla_computation.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -47,8 +48,8 @@ class XlaContext : public ResourceBase { XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, bool is_entry_computation, - const std::function<TensorShape(const TensorShape&, DataType)>* - shape_representation_fn); + const std::function<xla::StatusOr<TensorShape>( + const TensorShape&, DataType)>* shape_representation_fn); // Virtual method defined by ResourceBase. string DebugString() override; @@ -101,8 +102,8 @@ class XlaContext : public ResourceBase { // Returns the XLA shape to be used to represent a variable of TF `shape` // and `type`, or of an argument or return value of a top-level computation. - TensorShape RepresentationShape(const TensorShape& shape, - DataType type) const; + xla::StatusOr<TensorShape> RepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a @@ -160,7 +161,7 @@ class XlaContext : public ResourceBase { // should be represented in XLA. Parameters/return values will be shaped // according to this function, and reshaped back to/from their declared shapes // for computations. Must be non-null. - const std::function<TensorShape(const TensorShape&, DataType)>* + const std::function<xla::StatusOr<TensorShape>(const TensorShape&, DataType)>* shape_representation_fn_; // Cache of prebuilt computations indexed by their type. diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 0eabfb3a52..359cb4c467 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/dma_helper.h" namespace tensorflow { @@ -353,8 +354,9 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, } XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = - xla_context.RepresentationShape(variable->shape(), variable->type()); + TF_ASSIGN_OR_RETURN( + TensorShape representation_shape, + xla_context.RepresentationShape(variable->shape(), variable->type())); if (representation_shape == variable->shape()) { *value = variable->value(); } else { @@ -474,8 +476,8 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = - xla_context.RepresentationShape(shape, type); + TF_ASSIGN_OR_RETURN(TensorShape representation_shape, + xla_context.RepresentationShape(shape, type)); if (shape != representation_shape) { handle = xla::Reshape(handle, representation_shape.dim_sizes()); } |