about summary refs log tree commit diff homepage
path: root/tensorflow
diff options
context:
space:
mode:
author    A. Unique TensorFlower <gardener@tensorflow.org> 2018-06-29 14:52:53 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-29 14:55:24 -0700
commit 234aade2123d7927822dd421cbcd219be25f105e (patch)
tree   36f5cf9a78f97d0478e066de858fdf15affbfd77 /tensorflow
parent 56679a3ee01e50d86db1f26479e58e1439587e08 (diff)
[TF:XLA] Change the return type of ShapeRepresentationFn from TensorShape to StatusOr<TensorShape>.
PiperOrigin-RevId: 202711909
Diffstat (limited to 'tensorflow')
-rw-r--r-- tensorflow/compiler/jit/xla_device_context.cc   | 20
-rw-r--r-- tensorflow/compiler/tf2xla/BUILD                |  1
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/retval_op.cc | 19
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler.cc      | 16
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler.h       |  4
-rw-r--r-- tensorflow/compiler/tf2xla/xla_context.cc       |  8
-rw-r--r-- tensorflow/compiler/tf2xla/xla_context.h        | 11
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_kernel.cc     | 10
8 files changed, 58 insertions, 31 deletions
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 37005479dc..0188faaf51 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -56,9 +56,9 @@ XlaTransferManager::XlaTransferManager(
transfer_as_literal_(transfer_as_literal),
shape_representation_fn_(std::move(shape_representation_fn)) {
if (!shape_representation_fn_) {
- shape_representation_fn_ = [](const TensorShape& shape, DataType dtype) {
- return shape;
- };
+ shape_representation_fn_ =
+ [](const TensorShape& shape,
+ DataType dtype) -> xla::StatusOr<TensorShape> { return shape; };
}
}
@@ -119,8 +119,13 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
CHECK(xla_tensor);
- TensorShape shape = shape_representation_fn_(device_tensor->shape(),
- device_tensor->dtype());
+ xla::StatusOr<TensorShape> shape_or_status = shape_representation_fn_(
+ device_tensor->shape(), device_tensor->dtype());
+ if (!shape_or_status.ok()) {
+ done(shape_or_status.status());
+ return;
+ }
+ TensorShape shape = shape_or_status.ValueOrDie();
if (!xla_tensor->has_shaped_buffer()) {
Status s = xla_tensor->AllocateShapedBuffer(
device_tensor->dtype(), shape, client_,
@@ -217,8 +222,9 @@ void XlaTransferManager::CopyDeviceTensorToDevice(const Tensor& src_tensor,
CHECK(xla_src && xla_dst)
<< "Missing destination tensor for device-to-device copy";
if (!xla_dst->has_shaped_buffer()) {
- TensorShape shape =
- shape_representation_fn_(src_tensor.shape(), src_tensor.dtype());
+ TF_ASSIGN_OR_RETURN(
+ TensorShape shape,
+ shape_representation_fn_(src_tensor.shape(), src_tensor.dtype()));
TF_RETURN_IF_ERROR(
xla_dst->AllocateShapedBuffer(src_tensor.dtype(), shape, client_,
stream_->parent()->device_ordinal()));
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index aa9c0596d1..1a09ef643f 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -164,6 +164,7 @@ cc_library(
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index db7ea775e2..5be70a4ded 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -62,10 +63,20 @@ class RetvalOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
} else {
TensorShape shape = ctx->InputShape(0);
- TensorShape representation_shape =
- tc.is_entry_computation()
- ? tc.RepresentationShape(shape, ctx->input_type(0))
- : shape;
+ ctx->SetStatus(is_constant.status());
+ TensorShape representation_shape;
+ if (tc.is_entry_computation()) {
+ xla::StatusOr<TensorShape> shape_or_status =
+ tc.RepresentationShape(shape, ctx->input_type(0));
+ if (!shape_or_status.ok()) {
+ ctx->SetStatus(shape_or_status.status());
+ return;
+ } else {
+ representation_shape = shape_or_status.ValueOrDie();
+ }
+ } else {
+ representation_shape = shape;
+ }
xla::XlaOp output = input;
if (tc.is_entry_computation()) {
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 0c98c20805..99535faedc 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -231,10 +231,13 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
case XlaCompiler::Argument::kConstant:
LOG(FATAL) << "Unreachable case";
case XlaCompiler::Argument::kParameter: {
- TensorShape shape =
- is_entry_computation
- ? options_.shape_representation_fn(arg.shape, arg.type)
- : arg.shape;
+ TensorShape shape;
+ if (is_entry_computation) {
+ TF_ASSIGN_OR_RETURN(
+ shape, options_.shape_representation_fn(arg.shape, arg.type));
+ } else {
+ shape = arg.shape;
+ }
return TensorShapeToXLAShape(arg.type, shape, xla_shape);
}
case XlaCompiler::Argument::kResource: {
@@ -242,8 +245,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
switch (arg.resource_kind) {
case XlaResource::kVariable: {
- TensorShape representation_shape =
- options_.shape_representation_fn(arg.shape, arg.type);
+ TF_ASSIGN_OR_RETURN(
+ TensorShape representation_shape,
+ options_.shape_representation_fn(arg.shape, arg.type));
return TensorShapeToXLAShape(arg.type, representation_shape,
xla_shape);
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 80593eaca5..079c99797e 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -242,7 +243,8 @@ class XlaCompiler {
std::shared_ptr<xla::XlaComputation> computation;
};
- typedef std::function<TensorShape(const TensorShape&, DataType)>
+ typedef std::function<xla::StatusOr<TensorShape>(const TensorShape&,
+ DataType)>
ShapeRepresentationFn;
struct Options {
// Name of the compilation device to use. It must be set by the caller.
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index d0b5606907..fd39a58ce6 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -66,8 +66,8 @@ XlaContext::XlaContext(
XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
- const std::function<TensorShape(const TensorShape&, DataType)>*
- shape_representation_fn)
+ const std::function<xla::StatusOr<TensorShape>(
+ const TensorShape&, DataType)>* shape_representation_fn)
: compiler_(compiler),
builder_(builder),
allow_cpu_custom_calls_(allow_cpu_custom_calls),
@@ -119,8 +119,8 @@ Status XlaContext::CreateResource(
return Status::OK();
}
-TensorShape XlaContext::RepresentationShape(const TensorShape& shape,
- DataType type) const {
+xla::StatusOr<TensorShape> XlaContext::RepresentationShape(
+ const TensorShape& shape, DataType type) const {
return (*shape_representation_fn_)(shape, type);
}
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 5960daaefd..38d8cd653c 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -47,8 +48,8 @@ class XlaContext : public ResourceBase {
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
bool is_entry_computation,
- const std::function<TensorShape(const TensorShape&, DataType)>*
- shape_representation_fn);
+ const std::function<xla::StatusOr<TensorShape>(
+ const TensorShape&, DataType)>* shape_representation_fn);
// Virtual method defined by ResourceBase.
string DebugString() override;
@@ -101,8 +102,8 @@ class XlaContext : public ResourceBase {
// Returns the XLA shape to be used to represent a variable of TF `shape`
// and `type`, or of an argument or return value of a top-level computation.
- TensorShape RepresentationShape(const TensorShape& shape,
- DataType type) const;
+ xla::StatusOr<TensorShape> RepresentationShape(const TensorShape& shape,
+ DataType type) const;
// Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
@@ -160,7 +161,7 @@ class XlaContext : public ResourceBase {
// should be represented in XLA. Parameters/return values will be shaped
// according to this function, and reshaped back to/from their declared shapes
// for computations. Must be non-null.
- const std::function<TensorShape(const TensorShape&, DataType)>*
+ const std::function<xla::StatusOr<TensorShape>(const TensorShape&, DataType)>*
shape_representation_fn_;
// Cache of prebuilt computations indexed by their type.
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 0eabfb3a52..359cb4c467 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
namespace tensorflow {
@@ -353,8 +354,9 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
}
XlaContext& xla_context = XlaContext::Get(context_);
- TensorShape representation_shape =
- xla_context.RepresentationShape(variable->shape(), variable->type());
+ TF_ASSIGN_OR_RETURN(
+ TensorShape representation_shape,
+ xla_context.RepresentationShape(variable->shape(), variable->type()));
if (representation_shape == variable->shape()) {
*value = variable->value();
} else {
@@ -474,8 +476,8 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
XlaContext& xla_context = XlaContext::Get(context_);
- TensorShape representation_shape =
- xla_context.RepresentationShape(shape, type);
+ TF_ASSIGN_OR_RETURN(TensorShape representation_shape,
+ xla_context.RepresentationShape(shape, type));
if (shape != representation_shape) {
handle = xla::Reshape(handle, representation_shape.dim_sizes());
}