author    Peter Hawkins <phawkins@google.com>  2018-02-02 11:31:06 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-02-02 11:36:32 -0800
commit  22116459b258d5753aa76410ab6f4d3cbc928a5a (patch)
tree    d2f906ab2dafa09b6d40028a951c68c4be409c85
parent  224874002f93fec471e401488e23d97d4f36c4fc (diff)
[TF:XLA] Improve/refactor the handling of resource types/shapes.
Previously we used an xla::Shape to track the shape of a resource (Variable, TensorArray, Stack). The xla::Shape described how the resource was represented to XLA, e.g., as a (buffer, size) pair for a Stack resource.

Instead, separate the TensorFlow abstract shape representation from the XLA shape representation and track it separately. This leads to simpler and more readable code.

PiperOrigin-RevId: 184310694
-rw-r--r-- tensorflow/compiler/jit/kernels/xla_launch_op.cc        |   4
-rw-r--r-- tensorflow/compiler/jit/xla_compilation_cache.cc        |  11
-rw-r--r-- tensorflow/compiler/tf2xla/graph_compiler.cc            |   4
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/stack_ops.cc         |  20
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc  |  11
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc  |  33
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/training_ops.cc      | 113
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/variable_ops.cc      |  51
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/while_op.cc          |  33
-rw-r--r-- tensorflow/compiler/tf2xla/tf2xla.cc                    |   4
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler.cc              | 133
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler.h               |  26
-rw-r--r-- tensorflow/compiler/tf2xla/xla_compiler_test.cc         |  26
-rw-r--r-- tensorflow/compiler/tf2xla/xla_context.cc               |  12
-rw-r--r-- tensorflow/compiler/tf2xla/xla_context.h                |  10
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_kernel.cc             |  30
-rw-r--r-- tensorflow/compiler/tf2xla/xla_op_kernel.h              |  11
-rw-r--r-- tensorflow/compiler/tf2xla/xla_resource.cc              | 139
-rw-r--r-- tensorflow/compiler/tf2xla/xla_resource.h               |  38
19 files changed, 385 insertions(+), 324 deletions(-)
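
As orientation for the diff below, here is a minimal sketch (not part of this patch) of how a kernel lazily initializes a TensorArray or Stack resource under the refactored API. It uses the post-patch XlaResource methods SetTypeAndShape and SetZeroValue that appear in the hunks that follow; the helper name InitializeIfNeeded is hypothetical.

    // Hedged sketch: lazy initialization of a resource with the new API.
    // `InitializeIfNeeded` is a hypothetical helper, not part of this patch.
    Status InitializeIfNeeded(xla::ComputationBuilder* builder, DataType dtype,
                              const TensorShape& elem_shape,
                              XlaResource* resource) {
      if (!resource->initialized()) {
        // The resource now records the TensorFlow abstract shape of one
        // element; the packed xla::Shape is derived from it on demand.
        TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
        TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
      }
      return Status::OK();
    }
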
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 17ae2bb25c..6353149e4a 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -376,8 +376,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES(ctx,
write.input_index >= 0 && write.input_index < ctx->num_inputs(),
errors::Internal("Invalid input index for variable write."));
- TensorShape write_shape;
- OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(write.shape, &write_shape));
gpu::DeviceMemoryBase buffer = output->buffer({output_num});
@@ -399,7 +397,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
// Looks up the owning Tensor by buffer address.
OP_REQUIRES_OK(
- ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write_shape,
+ ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape,
variable->tensor()));
++output_num;
}
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 21d3a54f1b..6d854a920e 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -148,8 +148,7 @@ Status BuildArguments(int num_constant_args,
XlaCompiler::Argument& arg = (*args)[input_num];
arg.kind = XlaCompiler::Argument::kConstant;
arg.type = input.dtype();
- TF_RETURN_IF_ERROR(
- TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape));
+ arg.shape = input.shape();
arg.constant_value = input;
++input_num;
}
@@ -170,8 +169,7 @@ Status BuildArguments(int num_constant_args,
arg.constant_value = input;
}
arg.type = input.dtype();
- TF_RETURN_IF_ERROR(
- TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape));
+ arg.shape = input.shape();
++input_num;
}
@@ -189,8 +187,7 @@ Status BuildArguments(int num_constant_args,
if (variable_args[variable_id].present) {
const Tensor& value = variable_args[variable_id].value;
arg.type = value.dtype();
- TF_RETURN_IF_ERROR(
- TensorShapeToXLAShape(value.dtype(), value.shape(), &arg.shape));
+ arg.shape = value.shape();
arg.initialized = true;
} else {
// The values of uninitialized variables are not passed as inputs, since
@@ -199,7 +196,7 @@ Status BuildArguments(int num_constant_args,
// uninitialized variables.
arg.initialized = false;
arg.type = DT_INVALID;
- arg.shape = xla::Shape();
+ arg.shape = TensorShape();
}
++input_num;
}
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index 02215b5112..1418d95956 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -60,9 +60,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
for (int i = 0; i < args->size(); ++i) {
XlaCompiler::Argument& arg = (*args)[i];
arg.type = ctx->input_type(i);
-
- TF_RETURN_IF_ERROR(
- TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape));
+ arg.shape = ctx->InputShape(i);
if (arg.type == DT_RESOURCE) {
return errors::InvalidArgument(
diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
index d77fb768ef..1a78c7ab9b 100644
--- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc
@@ -77,10 +77,8 @@ Status MaybeInitializeStack(xla::ComputationBuilder* builder,
// Stack has not been initialized.
xla::ComputationDataHandle zero =
XlaHelpers::Zero(builder, resource->type());
- TF_RETURN_IF_ERROR(resource->SetValue(
- dtype,
- builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()),
- builder->ConstantR0<int32>(0)})));
+ TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
+ TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
} else {
// Checks the expected shape matches the actual shape.
TensorShape actual_shape;
@@ -119,8 +117,8 @@ class StackOp : public XlaOpKernel {
string name = strings::StrCat("Stack: ", stack_name_);
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_,
- value, &resource));
- resource->set_tensor_array_size(size);
+ TensorShape(), value, /*tensor_array_size=*/size,
+ /*tensor_array_gradients=*/{}, &resource));
ctx->SetResourceOutput(0, resource);
}
@@ -164,11 +162,9 @@ class StackPushOp : public XlaOpKernel {
// TODO(phawkins): We don't check the index is in bounds --- there is no
// error mechanism in XLA.
- OP_REQUIRES_OK(
- ctx,
- resource->SetValue(
- dtype_, b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices),
- b->Add(index, b->ConstantR0<int32>(1))})));
+ OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple(
+ {b->DynamicUpdateSlice(ta, update, start_indices),
+ b->Add(index, b->ConstantR0<int32>(1))})));
ctx->SetOutput(0, value);
}
@@ -208,7 +204,7 @@ class StackPopOp : public XlaOpKernel {
xla::ComputationDataHandle index = b->GetTupleElement(state, 1);
index = b->Sub(index, b->ConstantR0<int32>(1));
- OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, b->Tuple({ta, index})));
+ OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index})));
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
auto start_indices =
diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
index f0525a5fb8..91c169428c 100644
--- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc
@@ -231,6 +231,7 @@ class StridedSliceAssignOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
}
void Compile(XlaOpKernelContext* ctx) override {
@@ -252,9 +253,9 @@ class StridedSliceAssignOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_,
&strides_tensor));
- DataType lhs_type;
TensorShape lhs_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape));
+ xla::ComputationDataHandle lhs;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs));
const TensorShape rhs_shape = ctx->InputShape(4);
@@ -282,9 +283,6 @@ class StridedSliceAssignOp : public XlaOpKernel {
" does not match r-value shape ", rhs_shape.DebugString(),
". Automatic broadcasting not yet implemented."));
- xla::ComputationDataHandle lhs;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs));
-
xla::ComputationDataHandle rhs = ctx->Input(4);
gtl::InlinedVector<int64, 4> dimensions_to_reverse;
@@ -320,13 +318,14 @@ class StridedSliceAssignOp : public XlaOpKernel {
lhs, rhs, ctx->builder()->ConstantR1<int64>(slice_begin));
}
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs));
}
private:
int32 begin_mask_, end_mask_;
int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_;
DataType index_type_;
+ DataType dtype_;
};
REGISTER_XLA_OP(Name("ResourceStridedSliceAssign")
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 9224072a3c..7cf9b796b9 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -62,15 +62,13 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
TF_RET_CHECK(resource->tensor_array_size() >= 0)
<< resource->name() << " size " << resource->tensor_array_size();
- TensorShape ta_shape;
- ta_shape.AddDim(resource->tensor_array_size());
- ta_shape.AppendShape(elem_shape);
if (!resource->initialized()) {
xla::ComputationDataHandle zero =
XlaHelpers::Zero(builder, resource->type());
- TF_RETURN_IF_ERROR(resource->SetValue(
- dtype, builder->Broadcast(zero, ta_shape.dim_sizes())));
+
+ TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape));
+ TF_RETURN_IF_ERROR(resource->SetZeroValue(builder));
} else {
// Checks the elem_shape matches the TensorArray shape.
auto shape_or_status = builder->GetShape(resource->value());
@@ -80,6 +78,10 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
TensorShape shape;
TF_RETURN_IF_ERROR(
XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
+
+ TensorShape ta_shape;
+ ta_shape.AddDim(resource->tensor_array_size());
+ ta_shape.AppendShape(elem_shape);
if (ta_shape != shape) {
return errors::InvalidArgument(
"Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ",
@@ -114,10 +116,8 @@ Status CheckTensorArrayIsInitialized(const string& op_name,
Status GetTensorArrayShape(const XlaResource* resource,
xla::ComputationBuilder* builder,
TensorShape* shape) {
- TF_RETURN_IF_ERROR(resource->GetShape(builder, shape));
- if (shape->dims() < 1) {
- return errors::InvalidArgument("TensorArray rank must be >= 1");
- }
+ *shape = resource->shape();
+ shape->InsertDim(0, resource->tensor_array_size());
return Status::OK();
}
@@ -160,8 +160,8 @@ class TensorArrayOp : public XlaOpKernel {
// Initializes the TensorArray value if we know the element shape.
// Otherwise, defer initialization to the first write.
xla::ComputationDataHandle value;
+ TensorShape shape;
if (element_shape_.IsFullyDefined()) {
- TensorShape shape;
CHECK(element_shape_.AsTensorShape(&shape));
TensorShape ta_shape;
ta_shape.AddDim(size);
@@ -175,8 +175,8 @@ class TensorArrayOp : public XlaOpKernel {
string name = strings::StrCat("TensorArray: ", tensor_array_name_);
OP_REQUIRES_OK(
ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
- dtype_, value, &var));
- var->set_tensor_array_size(size);
+ dtype_, shape, value, /*tensor_array_size=*/size,
+ /*tensor_array_gradients=*/{}, &var));
ctx->SetResourceOutput(0, var);
Tensor flow(DT_FLOAT, TensorShape({}));
@@ -230,7 +230,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
xla::ComputationDataHandle written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
- OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, written));
+ OP_REQUIRES_OK(ctx, resource->SetValue(written));
ctx->SetOutput(0, flow);
}
@@ -421,7 +421,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
}
}
- OP_REQUIRES_OK(ctx, resource->SetValue(dtype_, ta));
+ OP_REQUIRES_OK(ctx, resource->SetValue(ta));
ctx->SetOutput(0, flow);
}
@@ -525,9 +525,8 @@ class TensorArraySplitOp : public XlaOpKernel {
value_shape.DebugString(), " vs. ",
ta_shape.DebugString()));
- OP_REQUIRES_OK(
- ctx, resource->SetValue(
- dtype_, b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()))));
+ OP_REQUIRES_OK(ctx, resource->SetValue(b->Add(
+ ta, b->Reshape(value, ta_shape.dim_sizes()))));
ctx->SetOutput(0, flow);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
index 5534d1bfa1..f750f7003b 100644
--- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc
@@ -32,9 +32,24 @@ class ResourceApplyGradientDescent : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle handle;
xla::ComputationBuilder* b = ctx->builder();
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+ DataType type = ctx->input_type(1);
+ TensorShape var_shape;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle));
+
+ TensorShape alpha_shape = ctx->InputShape(1);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape),
+ errors::InvalidArgument("alpha is not a scalar: ",
+ alpha_shape.DebugString()));
+
+ TensorShape delta_shape = ctx->InputShape(2);
+ OP_REQUIRES(
+ ctx, var_shape.IsSameSize(delta_shape),
+ errors::InvalidArgument("var and delta do not have the same shape: ",
+ var_shape.DebugString(), " vs ",
+ delta_shape.DebugString()));
+
handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2)));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
REGISTER_XLA_OP(
@@ -52,18 +67,10 @@ class ResourceApplyMomentum : public XlaOpKernel {
DataType type = ctx->input_type(2);
- DataType var_type, accum_type;
TensorShape var_shape, accum_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
- OP_REQUIRES_OK(ctx,
- ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
-
- OP_REQUIRES(
- ctx, type == var_type && type == accum_type,
- errors::InvalidArgument(
- "Types of variable arguments to ResourceApplyMomentum must match: ",
- DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
- DataTypeString(accum_type)));
+ xla::ComputationDataHandle var, accum;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
errors::InvalidArgument(
@@ -86,10 +93,6 @@ class ResourceApplyMomentum : public XlaOpKernel {
errors::InvalidArgument("momentum is not a scalar: ",
momentum_shape.DebugString()));
- xla::ComputationDataHandle var, accum;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
-
xla::ComputationDataHandle lr = ctx->Input(2);
xla::ComputationDataHandle grad = ctx->Input(3);
xla::ComputationDataHandle momentum = ctx->Input(4);
@@ -122,18 +125,10 @@ class ResourceApplyAdagrad : public XlaOpKernel {
DataType type = ctx->input_type(2);
- DataType var_type, accum_type;
TensorShape var_shape, accum_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
- OP_REQUIRES_OK(ctx,
- ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
-
- OP_REQUIRES(
- ctx, type == var_type && type == accum_type,
- errors::InvalidArgument(
- "Types of variable arguments to ResourceApplyAdagrad must match: ",
- DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
- DataTypeString(accum_type)));
+ xla::ComputationDataHandle var, accum;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum));
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
errors::InvalidArgument(
@@ -151,9 +146,6 @@ class ResourceApplyAdagrad : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::ComputationDataHandle var, accum;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
xla::ComputationDataHandle lr = ctx->Input(2);
xla::ComputationDataHandle grad = ctx->Input(3);
@@ -175,18 +167,11 @@ class ResourceApplyAdam : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- DataType var_type, m_type, v_type;
TensorShape var_shape, m_shape, v_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &m_type, &m_shape));
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &v_type, &v_shape));
-
- OP_REQUIRES(
- ctx, dtype_ == var_type && dtype_ == m_type && dtype_ == v_type,
- errors::InvalidArgument(
- "Types of variable arguments to ResourceApplyRMSProp must match: ",
- DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " vs. ",
- DataTypeString(m_type), " vs. ", DataTypeString(v_type)));
+ xla::ComputationDataHandle var, m, v;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v));
TensorShape beta1_power_shape = ctx->InputShape(3);
TensorShape beta2_power_shape = ctx->InputShape(4);
@@ -228,10 +213,6 @@ class ResourceApplyAdam : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::ComputationDataHandle var, m, v;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &m));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &v));
xla::ComputationDataHandle beta1_power = ctx->Input(3);
xla::ComputationDataHandle beta2_power = ctx->Input(4);
xla::ComputationDataHandle lr = ctx->Input(5);
@@ -278,18 +259,11 @@ class ResourceApplyRMSProp : public XlaOpKernel {
DataType type = ctx->input_type(3);
- DataType var_type, ms_type, mom_type;
TensorShape var_shape, ms_shape, mom_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &ms_type, &ms_shape));
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &mom_type, &mom_shape));
-
- OP_REQUIRES(
- ctx, type == var_type && type == ms_type && type == mom_type,
- errors::InvalidArgument(
- "Types of variable arguments to ResourceApplyRMSProp must match: ",
- DataTypeString(type), " vs. ", DataTypeString(var_type), " vs. ",
- DataTypeString(ms_type), " vs. ", DataTypeString(mom_type)));
+ xla::ComputationDataHandle var, ms, mom;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom));
TensorShape lr_shape = ctx->InputShape(3);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
@@ -323,10 +297,6 @@ class ResourceApplyRMSProp : public XlaOpKernel {
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
- xla::ComputationDataHandle var, ms, mom;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &ms));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &mom));
xla::ComputationDataHandle lr = ctx->Input(3);
xla::ComputationDataHandle rho = ctx->Input(4);
xla::ComputationDataHandle momentum = ctx->Input(5);
@@ -373,20 +343,11 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
bool has_l2_shrinkage) {
xla::ComputationBuilder* b = ctx->builder();
- DataType var_type, accum_type, linear_type;
TensorShape var_shape, accum_shape, linear_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
- OP_REQUIRES_OK(ctx,
- ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
- OP_REQUIRES_OK(ctx,
- ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape));
-
- OP_REQUIRES(
- ctx, dtype == var_type && dtype == accum_type && dtype == linear_type,
- errors::InvalidArgument(
- "Types of variable arguments to ResourceApplyFtrlV2 must match: ",
- DataTypeString(dtype), " vs. ", DataTypeString(var_type), " and ",
- DataTypeString(accum_type), " and ", DataTypeString(linear_type)));
+ xla::ComputationDataHandle var, accum, linear;
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear));
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
errors::InvalidArgument(
@@ -438,10 +399,6 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype,
errors::InvalidArgument("lr_power is not a scalar: ",
lr_power_shape.DebugString()));
- xla::ComputationDataHandle var, accum, linear;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear));
xla::ComputationDataHandle grad = ctx->Input(3);
xla::ComputationDataHandle lr = ctx->Input(4);
xla::ComputationDataHandle l1 = ctx->Input(5);
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index 68847ae7a2..e4079ebf0b 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -33,21 +33,29 @@ class VarIsInitializedOp : public XlaOpKernel {
public:
explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- xla::ComputationDataHandle handle;
- bool initialized = ctx->ReadVariableInput(0, &handle).ok();
- ctx->SetOutput(0, ctx->builder()->ConstantR0<bool>(initialized));
+ XlaResource* variable;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable));
+ ctx->SetOutput(0,
+ ctx->builder()->ConstantR0<bool>(variable->initialized()));
}
};
REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp);
class ReadVariableOp : public XlaOpKernel {
public:
- explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ }
+
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle handle;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+ OP_REQUIRES_OK(
+ ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle));
ctx->SetOutput(0, handle);
}
+
+ private:
+ DataType dtype_;
};
REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp);
@@ -65,10 +73,12 @@ class AssignAddVariableOp : public XlaOpKernel {
public:
explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
+ DataType type = ctx->input_type(1);
xla::ComputationDataHandle handle;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
handle = ctx->builder()->Add(handle, ctx->Input(1));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
REGISTER_XLA_OP(
@@ -79,10 +89,12 @@ class AssignSubVariableOp : public XlaOpKernel {
public:
explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
+ DataType type = ctx->input_type(1);
xla::ComputationDataHandle handle;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
+ OP_REQUIRES_OK(ctx,
+ ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
handle = ctx->builder()->Sub(handle, ctx->Input(1));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
+ OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
}
};
REGISTER_XLA_OP(
@@ -95,28 +107,19 @@ class ResourceGatherOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* builder = ctx->builder();
- // Get the shape of the resource tensor.
- TensorShape resource_shape;
- DataType resource_dtype;
- OP_REQUIRES_OK(
- ctx, ctx->GetVariableTypeAndShape(0, &resource_dtype, &resource_shape));
-
- DataType expected_output_dtype = ctx->expected_output_dtype(0);
- OP_REQUIRES(ctx, resource_dtype == expected_output_dtype,
- errors::InvalidArgument(
- "Variable dtype is ", DataTypeString(resource_dtype),
- " but expected output dtype is ",
- DataTypeString(expected_output_dtype), "."));
+ DataType type = ctx->expected_output_dtype(0);
+ TensorShape resource_shape;
xla::ComputationDataHandle resource_handle;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &resource_handle));
+ OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
+ &resource_handle));
auto indices = ctx->Input(1);
auto indices_shape = ctx->InputShape(1);
DataType index_type = ctx->input_type(1);
xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice(
- ctx, resource_handle, resource_shape, indices, indices_shape, 0,
- resource_dtype, index_type, builder);
+ ctx, resource_handle, resource_shape, indices, indices_shape, 0, type,
+ index_type, builder);
ctx->SetOutput(0, gather);
}
};
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 4a711e4d9b..0ff1b65ae9 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -58,9 +58,8 @@ Status MakeXlaCompilerArgumentsFromInputs(
}
arg.type = resource->type();
- if (arg.initialized) {
- TF_RETURN_IF_ERROR(resource->PackedShape(ctx->builder(), &arg.shape));
- } else {
+ arg.shape = resource->shape();
+ if (!arg.initialized) {
*has_uninitialized_vars = true;
}
arg.tensor_array_size = resource->tensor_array_size();
@@ -70,14 +69,13 @@ Status MakeXlaCompilerArgumentsFromInputs(
arg.name = resource->name();
VLOG(2) << " resource " << resource->name()
<< " type: " << DataTypeString(arg.type)
- << " shape: " << xla::ShapeUtil::HumanString(arg.shape)
+ << " shape: " << arg.shape.DebugString()
<< " initialized: " << arg.initialized;
} else {
arg.kind = XlaCompiler::Argument::kParameter;
arg.type = ctx->input_type(i);
- TF_RETURN_IF_ERROR(
- TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape));
+ arg.shape = ctx->InputShape(i);
}
}
return Status::OK();
@@ -154,17 +152,14 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
XlaCompiler::Argument& arg = arguments[update.input_index];
if (!arg.initialized) {
VLOG(2) << "Update shape for argument " << update.input_index << " "
- << xla::ShapeUtil::HumanString(update.shape);
+ << update.shape.DebugString();
arg.initialized = true;
- xla::Shape shape = update.shape;
- if (!update.tensor_array_gradients_accessed.empty()) {
- shape = xla::ShapeUtil::GetTupleElementShape(shape, 0);
- }
- std::unique_ptr<xla::Literal> zero =
- xla::Literal::CreateFromShape(shape);
- OP_REQUIRES_OK(ctx, resource->SetValue(
- update.type, builder->ConstantLiteral(*zero)));
+ arg.shape = update.shape;
+ OP_REQUIRES_OK(ctx,
+ resource->SetTypeAndShape(update.type, update.shape));
+
+ OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder));
}
// Add any TensorArray gradients touched by the body to the enclosing
@@ -182,9 +177,6 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
for (const auto& gradient : resource->tensor_array_gradients()) {
arg.tensor_array_gradients.insert(gradient.first);
}
-
- // Recompute the argument shape.
- OP_REQUIRES_OK(ctx, resource->PackedShape(ctx->builder(), &arg.shape));
}
// Recompile the body with the "correct" resource shapes.
VLOG(1) << "Recompiling body with corrected resource shapes";
@@ -292,13 +284,12 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx,
resource->SetFromPack(
arguments[update.input_index].tensor_array_gradients,
- builder->GetTupleElement(while_result, pos),
- /*reset_initial_values=*/false, builder));
+ builder->GetTupleElement(while_result, pos), builder));
}
VLOG(2) << "Loop-carried variable: pos: " << update.input_index
<< " name: " << resource->name() << " modified: " << update.modified
<< " type: " << DataTypeString(update.type)
- << " shape: " << xla::ShapeUtil::HumanString(update.shape);
+ << " shape: " << update.shape.DebugString();
// Copies the identity of the resource variable from input to output
// unchanged, even if the variable was not modified.
ctx->op_kernel_context()->set_output(
diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 906f229043..6051d7dffd 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -241,9 +241,7 @@ Status CreateXlaArgs(const Graph& graph,
XlaCompiler::Argument arg;
arg.kind = XlaCompiler::Argument::kParameter;
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type));
- TensorShape shape;
- TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape));
- TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape));
+ TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape));
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name));
xla_args->push_back(arg);
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 69b265436b..c5b4ec5b15 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -66,13 +66,14 @@ Status CheckSignature(const DataTypeVector& types,
bool XlaCompiler::Argument::operator==(
const XlaCompiler::Argument& other) const {
- if (std::tie(kind, resource_kind, type, name, tensor_array_size,
+ if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size,
tensor_array_gradients) !=
std::tie(other.kind, other.resource_kind, other.type, other.name,
- other.tensor_array_size, other.tensor_array_gradients)) {
+ other.initialized, other.tensor_array_size,
+ other.tensor_array_gradients)) {
return false;
}
- if (!xla::ShapeUtil::Equal(shape, other.shape)) {
+ if (shape != other.shape) {
return false;
}
if (constant_value.shape() != other.constant_value.shape()) {
@@ -230,6 +231,64 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
return Status::OK();
}
+// Computes the XLA shape for argument 'arg'.
+/*static*/ Status XlaCompiler::XLAShapeForArgument(
+ const XlaCompiler::Argument& arg, xla::Shape* xla_shape) {
+ switch (arg.kind) {
+ case XlaCompiler::Argument::kConstant:
+ return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
+ xla_shape);
+ case XlaCompiler::Argument::kParameter:
+ return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
+ case XlaCompiler::Argument::kResource: {
+ TF_RET_CHECK(arg.initialized);
+
+ switch (arg.resource_kind) {
+ case XlaResource::kVariable:
+ return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
+ case XlaResource::kTensorArray: {
+ if (arg.tensor_array_size < 0) {
+ return errors::InvalidArgument(
+ "Negative tensor_array_size in XLAShapeForArgument");
+ }
+ TensorShape shape;
+ shape.AddDim(arg.tensor_array_size);
+ shape.AppendShape(arg.shape);
+ TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape));
+
+ if (!arg.tensor_array_gradients.empty()) {
+ std::vector<xla::Shape> tuple_shape(
+ arg.tensor_array_gradients.size() + 1, *xla_shape);
+ *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape);
+ }
+ return Status::OK();
+ }
+ case XlaResource::kStack: {
+ if (arg.tensor_array_size < 0) {
+ return errors::InvalidArgument(
+ "Negative tensor_array_size in XLAShapeForArgument");
+ }
+ TensorShape shape;
+ shape.AddDim(arg.tensor_array_size);
+ shape.AppendShape(arg.shape);
+ xla::Shape buffer_shape;
+ TF_RETURN_IF_ERROR(
+ TensorShapeToXLAShape(arg.type, shape, &buffer_shape));
+ *xla_shape = xla::ShapeUtil::MakeTupleShape(
+ {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})});
+ return Status::OK();
+ }
+
+ case XlaResource::kInvalid:
+ return errors::Internal(
+ "Invalid resource type in XLAShapeForArgument()");
+ }
+ }
+ case XlaCompiler::Argument::kInvalid:
+ return errors::Internal("Invalid argument type in XLAShapeForArgument()");
+ }
+}
+
namespace {
Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
@@ -275,8 +334,9 @@ Status BuildArguments(const Graph& graph,
// Argument numbers of arguments and resources that are to be passed to the
// XLA computation as runtime parameters.
- std::vector<int> parameters, resources;
- parameters.reserve(args.size());
+ input_mapping->clear();
+ input_mapping->reserve(args.size());
+ std::vector<int> resources;
resources.reserve(args.size());
// Fills in constant arguments, and computes non-constant argument order.
@@ -290,18 +350,20 @@ Status BuildArguments(const Graph& graph,
// TODO(phawkins): this code assumes that resource arguments do not
// alias.
XlaResource* resource;
- TF_RETURN_IF_ERROR(
- context->CreateResource(arg.resource_kind, i, arg.name, arg.type,
- xla::ComputationDataHandle(), &resource));
- resource->set_tensor_array_size(arg.tensor_array_size);
+ TF_RETURN_IF_ERROR(context->CreateResource(
+ arg.resource_kind, i, arg.name, arg.type, arg.shape,
+ xla::ComputationDataHandle(),
+ /*tensor_array_size=*/arg.tensor_array_size,
+ /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
arg_expression.set_resource(resource);
if (arg.initialized) {
resources.push_back(i);
}
break;
- case XlaCompiler::Argument::kParameter:
- parameters.push_back(i);
+ case XlaCompiler::Argument::kParameter: {
+ input_mapping->push_back(i);
break;
+ }
case XlaCompiler::Argument::kConstant:
arg_expression.set_constant_value(arg.constant_value);
break;
@@ -312,19 +374,17 @@ Status BuildArguments(const Graph& graph,
// Append parameters containing variable values after the other runtime
// parameters.
- parameters.insert(parameters.end(), resources.begin(), resources.end());
- if (parameters.empty()) {
+ input_mapping->insert(input_mapping->end(), resources.begin(),
+ resources.end());
+ if (input_mapping->empty()) {
return Status::OK();
}
- std::vector<xla::Shape> arg_shapes;
- arg_shapes.reserve(parameters.size());
- input_mapping->resize(parameters.size());
- for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
- const XlaCompiler::Argument& arg = args[parameters[i]];
+ std::vector<xla::Shape> arg_shapes(input_mapping->size());
+ for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
// Computes the shapes of non-constant arguments.
- arg_shapes.push_back(arg.shape);
- (*input_mapping)[i] = parameters[i];
+ TF_RETURN_IF_ERROR(XlaCompiler::XLAShapeForArgument(
+ args[(*input_mapping)[i]], &arg_shapes[i]));
}
if (use_tuple_arg) {
@@ -354,13 +414,13 @@ Status BuildArguments(const Graph& graph,
}
// Build parameter handles for non-constant arguments.
- std::vector<xla::ComputationDataHandle> arg_handles(parameters.size());
+ std::vector<xla::ComputationDataHandle> arg_handles(input_mapping->size());
if (use_tuple_arg) {
xla::ComputationDataHandle tuple;
if (is_entry_computation) {
xla::OpSharding tuple_sharding;
tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
- for (int64 parameter : parameters) {
+ for (int64 parameter : *input_mapping) {
const int core = (*arg_cores)[parameter];
const int root_device = 0;
*tuple_sharding.add_tuple_shardings() =
@@ -373,16 +433,16 @@ Status BuildArguments(const Graph& graph,
} else {
tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple");
}
- for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
- const int core = (*arg_cores)[parameters[i]];
+ for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
+ const int core = (*arg_cores)[input_mapping->at(i)];
xla::ScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
arg_handles[i] = builder->GetTupleElement(tuple, i);
}
} else {
- for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
- const int core = (*arg_cores)[parameters[i]];
+ for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
+ const int core = (*arg_cores)[input_mapping->at(i)];
xla::ScopedShardingAssignment assign_sharding(
builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
: xla::sharding_builder::AssignDevice(core));
@@ -393,19 +453,18 @@ Status BuildArguments(const Graph& graph,
// Fill in the handles in non-constant arguments.
VLOG(2) << "XLA computation inputs:";
- for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
- const XlaCompiler::Argument& arg = args[parameters[i]];
+ for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
+ const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
VLOG(2) << " XLA arg " << i
<< " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i])
- << " name: " << arg.name << " TF arg " << parameters[i];
- XlaExpression& arg_expression = (*arg_expressions)[parameters[i]];
+ << " name: " << arg.name << " TF arg " << input_mapping->at(i);
+ XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)];
switch (arg.kind) {
case XlaCompiler::Argument::kResource: {
TF_RET_CHECK(arg.initialized);
XlaResource* resource = arg_expression.resource();
- TF_RETURN_IF_ERROR(
- resource->SetFromPack(arg.tensor_array_gradients, arg_handles[i],
- /*reset_initial_values=*/true, builder));
+ TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
+ arg_handles[i], builder));
VLOG(2) << " resource: num_gradients: "
<< arg.tensor_array_gradients.size();
break;
@@ -486,6 +545,7 @@ Status BuildComputation(
XlaCompiler::ResourceUpdate& update = resource_updates->back();
update.input_index = resource->arg_num();
update.type = resource->type();
+ update.shape = resource->shape();
update.modified = modified;
for (const auto& grad : resource->tensor_array_gradients()) {
update.tensor_array_gradients_accessed.insert(grad.first);
@@ -616,13 +676,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
++computation_output;
}
}
-
- for (std::vector<ResourceUpdate>::size_type i = 0;
- i < result->resource_updates.size(); ++i) {
- result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape(
- result->xla_output_shape, computation_output);
- ++computation_output;
- }
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 30d3c05ee9..b86c82c0ab 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -104,9 +104,17 @@ class XlaCompiler {
// is the type of the variable's value, not DT_RESOURCE.
DataType type;
- // The shape of the argument. If the argument is a resource, this is the
- // shape of the resource's value.
- xla::Shape shape;
+ // The shape of the argument. For:
+ // * a parameter: the shape of the parameter.
+ // * a constant: ignored; the shape given by constant_value is used
+ // instead.
+ // * an uninitialized resource: ignored. We don't yet know the shape of an
+ // uninitialized resource (otherwise we would have initialized it!)
+ // * an initialized variable: the shape of the variable's value.
+ // * an initialized TensorArray or Stack resource: the shape of an entry in
+ // the TensorArray/Stack. Note this is the size of a single entry, not the
+ // XLA data structure that represents the complete stack/array.
+ TensorShape shape;
// The value of the argument, if it is a compile-time constant. Must be a
// host-memory tensor.
@@ -175,8 +183,9 @@ class XlaCompiler {
int input_index;
// Type and shape of the tensor to be written back.
+ // The `shape` field has the same meaning as the Argument::shape field.
DataType type;
- xla::Shape shape;
+ TensorShape shape;
// Was the value of the variable modified by the computation?
// (Always true, unless `return_updated_values_for_all_resources` is true.)
@@ -266,11 +275,10 @@ class XlaCompiler {
const std::vector<Argument>& args,
CompilationResult* result);
- Status PrepareArguments(xla::ComputationBuilder* builder, NameAttrList func,
- const std::vector<DataType>& types,
- const std::vector<TensorShape>& shapes,
- const std::vector<const XlaExpression*>& expressions,
- std::vector<Argument>* args);
+ // Returns the shape of the XLA parameter for an argument 'arg'.
+ // See the class comment for more details about the argument passing
+ // convention.
+ static Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
// Retrieves the channel handle associated with `key`. Allocates
// a new channel handle if none exists.
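
As a worked illustration of the argument-shape convention documented above (not part of this patch), an initialized TensorArray argument with entry shape [2], dtype DT_FLOAT, size 3 and one registered gradient packs into a tuple of two f32[3,2] buffers:

    // Hedged sketch, assuming the XLAShapeForArgument logic added in
    // xla_compiler.cc above.
    XlaCompiler::Argument arg;
    arg.kind = XlaCompiler::Argument::kResource;
    arg.resource_kind = XlaResource::kTensorArray;
    arg.initialized = true;
    arg.type = DT_FLOAT;
    arg.shape = TensorShape({2});       // shape of a single TensorArray entry
    arg.tensor_array_size = 3;
    arg.tensor_array_gradients = {"grad"};

    xla::Shape xla_shape;
    TF_CHECK_OK(XlaCompiler::XLAShapeForArgument(arg, &xla_shape));
    // xla_shape is (f32[3,2], f32[3,2]): the array buffer plus one buffer per
    // gradient, each of shape [tensor_array_size] + entry shape.
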
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 7ebe4b75bc..65de4dbad7 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -191,10 +191,10 @@ TEST_F(XlaCompilerTest, Simple) {
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[0].shape = TensorShape({2});
args[1].kind = XlaCompiler::Argument::kParameter;
args[1].type = DT_INT32;
- args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[1].shape = TensorShape({2});
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
@@ -242,10 +242,10 @@ TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
std::vector<XlaCompiler::Argument> args(2);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[0].shape = TensorShape({2});
args[1].kind = XlaCompiler::Argument::kParameter;
args[1].type = DT_INT32;
- args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[1].shape = TensorShape({2});
// Compiles the graph.
XlaCompiler compiler(DefaultOptions());
@@ -281,7 +281,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) {
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[0].shape = TensorShape({2});
XlaCompiler::Options options = DefaultOptions();
XlaCompiler compiler(options);
@@ -373,7 +373,7 @@ TEST_F(XlaCompilerTest, ResourceManager) {
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[0].shape = TensorShape({2});
DummyResourceForTest* resource = new DummyResourceForTest();
@@ -420,7 +420,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
std::vector<XlaCompiler::Argument> args(1);
args[0].kind = XlaCompiler::Argument::kParameter;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2});
+ args[0].shape = TensorShape({2});
// Compiles the graph.
auto options = DefaultOptions();
@@ -472,9 +472,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
args[0].resource_kind = XlaResource::kTensorArray;
args[0].initialized = true;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeTupleShape(
- {xla::ShapeUtil::MakeShape(xla::S32, {2}),
- xla::ShapeUtil::MakeShape(xla::S32, {2})});
+ args[0].shape = TensorShape({});
args[0].tensor_array_size = 2;
args[0].tensor_array_gradients = {"grad2"};
@@ -540,9 +538,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
args[0].resource_kind = XlaResource::kTensorArray;
args[0].initialized = true;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeTupleShape(
- {xla::ShapeUtil::MakeShape(xla::S32, {2}),
- xla::ShapeUtil::MakeShape(xla::S32, {2})});
+ args[0].shape = TensorShape({});
args[0].tensor_array_size = 2;
args[0].tensor_array_gradients = {"grad1"};
@@ -574,9 +570,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
args[0].resource_kind = XlaResource::kTensorArray;
args[0].initialized = true;
args[0].type = DT_INT32;
- args[0].shape = xla::ShapeUtil::MakeTupleShape(
- {xla::ShapeUtil::MakeShape(xla::S32, {2}),
- xla::ShapeUtil::MakeShape(xla::S32, {2})});
+ args[0].shape = TensorShape({});
args[0].tensor_array_size = 2;
args[0].tensor_array_gradients = {"grad1"};
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index e8d17e2e0a..73878955e3 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -103,12 +103,14 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
xla::ComputationBuilder* XlaContext::builder() { return builder_; }
-Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num,
- string name, DataType type,
- const xla::ComputationDataHandle& handle,
- XlaResource** resource) {
+Status XlaContext::CreateResource(
+ XlaResource::Kind kind, int arg_num, string name, DataType type,
+ TensorShape shape, const xla::ComputationDataHandle& handle,
+ int64 tensor_array_size, const std::set<string>& tensor_array_gradients,
+ XlaResource** resource) {
resources_.emplace_back(
- new XlaResource(kind, arg_num, std::move(name), type, handle));
+ new XlaResource(kind, arg_num, std::move(name), type, std::move(shape),
+ handle, tensor_array_size, tensor_array_gradients));
*resource = resources_.back().get();
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 1a7dafe8cd..fac0352ae8 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -71,11 +71,15 @@ class XlaContext : public ResourceBase {
Status AddConstRetval(int retval_index, DataType dtype,
const xla::Literal& literal);
- // Creates a resource with resource `kind` and initial type `type` and
- // value `handle`. `name` is a descriptive name for use in error messages.
+ // Creates a resource with resource `kind` and initial value `handle`. `name`
+ // is a descriptive name for use in error messages. See the `XlaResource`
+ // constructor for a description of the remaining arguments.
// Fails if the resource already exists.
Status CreateResource(XlaResource::Kind kind, int arg_num, string name,
- DataType type, const xla::ComputationDataHandle& handle,
+ DataType type, TensorShape shape,
+ const xla::ComputationDataHandle& handle,
+ int64 tensor_array_size,
+ const std::set<string>& tensor_array_gradients,
XlaResource** resource);
const std::vector<std::unique_ptr<XlaResource>>& resources() {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index ee0aed672e..ee29158646 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -286,7 +286,8 @@ Status XlaOpKernelContext::ConstantInputList(
}
Status XlaOpKernelContext::ReadVariableInput(
- int index, xla::ComputationDataHandle* value) {
+ int index, DataType type, TensorShape* shape,
+ xla::ComputationDataHandle* value) {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
XlaResource* variable = expression->resource();
@@ -296,7 +297,15 @@ Status XlaOpKernelContext::ReadVariableInput(
return errors::InvalidArgument("Read of uninitialized variable ",
variable->name());
}
+ if (variable->type() != type) {
+ return errors::InvalidArgument(
+ "Type mismatch for read of variable ", variable->name(), ". Expected ",
+ DataTypeString(type), "; got ", DataTypeString(variable->type()));
+ }
*value = variable->value();
+ if (shape) {
+ *shape = variable->shape();
+ }
return Status::OK();
}
@@ -312,12 +321,7 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
variable->name());
}
*type = variable->type();
- auto shape_or_status = builder()->GetShape(variable->value());
- if (!shape_or_status.ok()) {
- return shape_or_status.status();
- }
- TF_RETURN_IF_ERROR(
- XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape));
+ *shape = variable->shape();
return Status::OK();
}
@@ -405,7 +409,17 @@ Status XlaOpKernelContext::AssignVariable(
XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
- return variable->SetValue(type, handle);
+
+ auto shape_or_status = builder()->GetShape(handle);
+ if (!shape_or_status.ok()) {
+ return shape_or_status.status();
+ }
+ TensorShape shape;
+ TF_RETURN_IF_ERROR(
+ XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
+
+ TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
+ return variable->SetValue(handle);
}
XlaCompiler* XlaOpKernelContext::compiler() const {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 6d3b6db228..e1fd0f55c6 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -164,11 +164,16 @@ class XlaOpKernelContext {
TensorShape* shape) const;
// Reads the current value of the resouce variable referred to by input
- // 'index'.
- Status ReadVariableInput(int index, xla::ComputationDataHandle* value);
+ // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the
+ // variable. Returns an error if the variable has not been initialized, or if
+ // its type does not match `type`.
+ Status ReadVariableInput(int index, DataType type, TensorShape* shape,
+ xla::ComputationDataHandle* value);
// Assigns the value `handle` to the variable referenced by input
- // `input_index`. Marks the operator as having side effects.
+ // `input_index`. The variable must be of `type`. Returns an error if the
+ // variable has been initialized with a different type or with a
+ // different shape.
Status AssignVariable(int input_index, DataType type,
const xla::ComputationDataHandle& handle);
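
To make the new calling convention concrete, here is a minimal sketch (not part of this patch) of a kernel reading and then reassigning a variable with the post-patch ReadVariableInput/AssignVariable signatures, mirroring the training_ops.cc changes above:

    // Hedged sketch of the post-patch read/modify/write pattern.
    void Compile(XlaOpKernelContext* ctx) override {
      DataType type = ctx->input_type(1);
      TensorShape var_shape;
      xla::ComputationDataHandle var;
      // One call now returns the value, checks the dtype, and reports the
      // TensorFlow shape of the variable.
      OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var));
      var = ctx->builder()->Sub(var, ctx->Input(1));
      // AssignVariable records both the type and the (inferred) shape on the
      // underlying XlaResource.
      OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
    }
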
diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc
index 9abac8bdaa..c2075b44b8 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.cc
+++ b/tensorflow/compiler/tf2xla/xla_resource.cc
@@ -25,51 +25,99 @@ limitations under the License.
namespace tensorflow {
-XlaResource::XlaResource(Kind kind, int arg_num, string name,
- DataType initial_type,
- const xla::ComputationDataHandle& initial_value)
+XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type,
+ TensorShape shape,
+ const xla::ComputationDataHandle& initial_value,
+ int64 tensor_array_size,
+ const std::set<string>& tensor_array_gradients)
: kind_(kind),
arg_num_(arg_num),
name_(std::move(name)),
- type_(initial_type),
+ type_(type),
+ shape_(std::move(shape)),
value_(initial_value),
- initial_value_(initial_value) {
+ initial_value_(initial_value),
+ tensor_array_size_(tensor_array_size) {
CHECK(kind_ != kInvalid);
+
+ for (const string& gradient : tensor_array_gradients) {
+ tensor_array_gradients_[gradient].reset(
+ new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
+ /*name=*/strings::StrCat("TensorArrayGrad: ", name_),
+ type_, shape_, xla::ComputationDataHandle(),
+ tensor_array_size_, /*tensor_array_gradients=*/{}));
+ }
}
-Status XlaResource::SetValue(DataType type,
- const xla::ComputationDataHandle& value) {
- if (type_ == DT_INVALID && type == DT_INVALID) {
- return errors::InvalidArgument("Attempted to initialized resource ", name_,
- " to an invalid type");
+Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) {
+ if (type == DT_INVALID) {
+ return errors::InvalidArgument("Attempted to set type of resource '", name_,
+ "'' to an invalid type");
}
- if (type_ != DT_INVALID && type_ != type) {
+ if (initialized() && type_ != type) {
return errors::InvalidArgument("Type of resource ", name_,
" cannot be changed after initialization: "
"old type was ",
DataTypeString(type_), ", new type is ",
DataTypeString(type));
}
+ if (initialized() && shape_ != shape) {
+ return errors::InvalidArgument("Shape of resource ", name_,
+ " cannot be changed after initialization: "
+ "old shape was ",
+ shape_.DebugString(), ", new shape is ",
+ shape.DebugString());
+ }
type_ = type;
- value_ = value;
+ shape_ = shape;
return Status::OK();
}
-Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder,
- xla::Shape* shape) const {
- auto shape_or_status = builder->GetShape(value_);
- if (!shape_or_status.ok()) {
- return shape_or_status.status();
+Status XlaResource::SetValue(const xla::ComputationDataHandle& value) {
+ if (type_ == DT_INVALID) {
+ return errors::InvalidArgument(
+ "Resource '", name_,
+ "' must be initialized with a valid type before use.");
}
- *shape = *shape_or_status.ValueOrDie();
+ value_ = value;
return Status::OK();
}
-Status XlaResource::GetShape(xla::ComputationBuilder* builder,
- TensorShape* shape) const {
- xla::Shape xla_shape;
- TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape));
- TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape));
+Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) {
+ if (type_ == DT_INVALID) {
+ return errors::InvalidArgument(
+ "Resource '", name_,
+ "' must be initialized with a valid type before use.");
+ }
+ switch (kind_) {
+ case kVariable: {
+ value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
+ shape_.dim_sizes());
+ break;
+ }
+ case kTensorArray: {
+ TensorShape ta_shape;
+ ta_shape.AddDim(tensor_array_size_);
+ ta_shape.AppendShape(shape_);
+ value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_),
+ ta_shape.dim_sizes());
+ break;
+ }
+ case kStack: {
+ TensorShape ta_shape;
+ ta_shape.AddDim(tensor_array_size_);
+ ta_shape.AppendShape(shape_);
+ value_ =
+ builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_),
+ ta_shape.dim_sizes()),
+ builder->ConstantR0<int32>(0)});
+ break;
+ }
+
+ case kInvalid:
+ default:
+ LOG(FATAL) << "Invalid resource type";
+ }
return Status::OK();
}
@@ -82,36 +130,20 @@ Status XlaResource::GetOrCreateTensorArrayGradient(
std::unique_ptr<XlaResource>& gradient = tensor_array_gradients_[source];
if (!gradient) {
TensorShape ta_shape;
- TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape));
+ ta_shape.AddDim(tensor_array_size_);
+ ta_shape.AppendShape(shape_);
xla::ComputationDataHandle gradient_value = builder->Broadcast(
XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes());
gradient.reset(
new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1,
/*name=*/strings::StrCat("TensorArrayGrad: ", name_),
- type_, gradient_value));
- gradient->tensor_array_size_ = tensor_array_size_;
+ type_, shape_, gradient_value, tensor_array_size_,
+ /*tensor_array_gradients=*/{}));
}
*gradient_out = gradient.get();
return Status::OK();
}
-Status XlaResource::PackedShape(xla::ComputationBuilder* builder,
- xla::Shape* packed_shape) const {
- if (tensor_array_gradients_.empty()) {
- return GetXlaShape(builder, packed_shape);
- }
- TF_RET_CHECK(kind_ == kTensorArray);
- std::vector<xla::Shape> elem_shapes(1 + tensor_array_gradients_.size());
- int pos = 0;
- TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++]));
- for (const auto& gradient : tensor_array_gradients_) {
- TF_RETURN_IF_ERROR(
- gradient.second->GetXlaShape(builder, &elem_shapes[pos++]));
- }
- *packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes);
- return Status::OK();
-}
-
Status XlaResource::Pack(xla::ComputationDataHandle* pack,
xla::ComputationBuilder* builder) const {
if (tensor_array_gradients_.empty()) {
@@ -130,27 +162,32 @@ Status XlaResource::Pack(xla::ComputationDataHandle* pack,
Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
const xla::ComputationDataHandle& pack,
- bool reset_initial_values,
xla::ComputationBuilder* builder) {
if (gradient_sources.empty()) {
+ if (!initialized()) {
+ initial_value_ = pack;
+ }
value_ = pack;
} else {
TF_RET_CHECK(kind_ == kTensorArray);
int pos = 0;
- value_ = builder->GetTupleElement(pack, pos++);
+ auto v = builder->GetTupleElement(pack, pos++);
+ if (!initialized()) {
+ initial_value_ = v;
+ }
+ value_ = v;
+
for (const auto& source : gradient_sources) {
XlaResource* gradient;
TF_RETURN_IF_ERROR(
GetOrCreateTensorArrayGradient(source, builder, &gradient));
- gradient->value_ = builder->GetTupleElement(pack, pos++);
- if (reset_initial_values) {
- gradient->initial_value_ = gradient->value_;
+ auto v = builder->GetTupleElement(pack, pos++);
+ if (!gradient->initialized()) {
+ gradient->initial_value_ = v;
}
+ gradient->value_ = v;
}
}
- if (reset_initial_values) {
- initial_value_ = value_;
- }
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h
index 6b46089e4f..1bb2c7274e 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.h
+++ b/tensorflow/compiler/tf2xla/xla_resource.h
@@ -36,8 +36,11 @@ class XlaResource {
kStack,
};
- XlaResource(Kind kind, int arg_num, string name, DataType initial_type,
- const xla::ComputationDataHandle& initial_value);
+ XlaResource(Kind kind, int arg_num, string name, DataType type,
+ TensorShape shape,
+ const xla::ComputationDataHandle& initial_value,
+ int64 tensor_array_size,
+ const std::set<string>& tensor_array_gradients);
XlaResource(const XlaResource&) = delete;
XlaResource(XlaResource&&) = delete;
@@ -60,6 +63,12 @@ class XlaResource {
// a resource is first initialized we do not yet know its type, so we keep
// track of its type dynamically.
DataType type() const { return type_; }
+
+ // Shape of the resource. For an uninitialized resource, this is ignored.
+ // For a Variable, this is the shape of the value. For a TensorArray or Stack
+ // this is the shape of each entry in the TensorArray/Stack.
+ const TensorShape& shape() const { return shape_; }
+
const xla::ComputationDataHandle& value() const { return value_; }
// Value of the resource at computation entry. Used to detect which
@@ -68,17 +77,19 @@ class XlaResource {
return initial_value_;
}
+ // A variable is initialized if it has a value.
bool initialized() const { return value_.handle() > 0; }
- // Sets the current type/value of the resource.
- Status SetValue(DataType type, const xla::ComputationDataHandle& value);
+ // Sets the type and shape of the resource. The type and shape of a resource
+ // must not change once the variable has been initialized.
+ Status SetTypeAndShape(DataType type, const TensorShape& shape);
- // Returns the shape of the resource as an xla::Shape.
- Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const;
+ // Sets the current value of the resource. Returns an error if the type is not
+ // set to a valid value.
+ Status SetValue(const xla::ComputationDataHandle& value);
- // Returns the shape of the resource as an TensorShape. Fails if the shape is
- // not representable as a TensorShape.
- Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const;
+ // Sets the current value of the resource to an all-zero value.
+ Status SetZeroValue(xla::ComputationBuilder* builder);
// Looks up the gradient for `source`, or creates it if it does not already
// exist. The call target must be an initialized TensorArray resource. A
@@ -96,10 +107,6 @@ class XlaResource {
Status Pack(xla::ComputationDataHandle* pack,
xla::ComputationBuilder* builder) const;
- // Returns the shape of the `pack` value computed by `Pack()`.
- Status PackedShape(xla::ComputationBuilder* builder,
- xla::Shape* packed_shape) const;
-
// Updates the resource with values from `pack`. If `gradient_sources` is
// non-empty, treats `pack` as a tuple that represents a TensorArray and
// its gradients, and unpacks and updates the gradient resources.
@@ -108,14 +115,14 @@ class XlaResource {
// Opposite of Pack().
Status SetFromPack(const std::set<string>& gradient_sources,
const xla::ComputationDataHandle& pack,
- bool reset_initial_values,
xla::ComputationBuilder* builder);
- // TensorArray-specific fields
+ // TensorArray and Stack specific fields
// 'tensor_array_size' stores the expected size of the TensorArray or Stack.
// We need to store this since sometimes TensorArrays must be initialized
// lazily since we do not know the element shape at construction time.
+ // Used by both TensorArrays and Stacks.
int64 tensor_array_size() const { return tensor_array_size_; }
void set_tensor_array_size(int64 size) { tensor_array_size_ = size; }
@@ -136,6 +143,7 @@ class XlaResource {
const string name_;
DataType type_;
+ TensorShape shape_;
xla::ComputationDataHandle value_;
xla::ComputationDataHandle initial_value_;