aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-06-16 10:32:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-16 10:36:33 -0700
commita66de1eca225bc95e7972974a7089d84df8a8055 (patch)
tree7a332e8fd1f1656a0ef0caacbfff0851ab13967f
parentd06121593035687176d8b660b83bab568853deff (diff)
[TF:XLA] Refactor handling of Resources (Variables and TensorArrays) in the XLA bridge.
* Rename "Variable" to "Resource" in many places where non-Variable resources might be used. * Add kTensorArray to the XlaCompiler::Argument enum. Remove kUninitializedVariable and make "initialized" a separate boolean field. * Add a kind field to XlaResource. Add checks that Variables are not used where TensorArrays are expected, and vice-versa. * Clean ups to the TensorArray operators. PiperOrigin-RevId: 159244478
-rw-r--r--tensorflow/compiler/jit/kernels/xla_device_launch_op.cc4
-rw-r--r--tensorflow/compiler/jit/xla_compilation_cache.cc5
-rw-r--r--tensorflow/compiler/tests/tensor_array_ops_test.py2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/arg_op.cc25
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc205
-rw-r--r--tensorflow/compiler/tf2xla/kernels/while_op.cc89
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.h35
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc73
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h39
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.cc20
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.h26
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc39
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h20
13 files changed, 327 insertions, 255 deletions
diff --git a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc
index 29c5ff7242..e786fdb16d 100644
--- a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc
@@ -202,8 +202,8 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
// Apply variable updates, if any.
VLOG(2) << "Applying variable updates";
- for (int i = 0; i < kernel->variable_updates.size(); ++i) {
- const XlaCompiler::VariableUpdate& write = kernel->variable_updates[i];
+ for (int i = 0; i < kernel->resource_updates.size(); ++i) {
+ const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
OP_REQUIRES(ctx,
write.input_index >= 0 && write.input_index < ctx->num_inputs(),
errors::Internal("Invalid input index for variable write."));
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 63ca77f9a9..2325217b97 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -182,17 +182,18 @@ Status BuildArguments(int num_constant_args,
XlaCompiler::Argument& arg = (*args)[input_num];
arg.name = variable_args[variable_id].name;
+ arg.kind = XlaCompiler::Argument::kVariable;
if (variable_args[variable_id].present) {
const Tensor& value = variable_args[variable_id].value;
- arg.kind = XlaCompiler::Argument::kVariable;
arg.type = value.dtype();
arg.shape = value.shape();
+ arg.initialized = true;
} else {
// The values of uninitialized variables are not passed as inputs, since
// they are meaningless. However, it is legal to assign to a resource
// variable for the first time inside the XLA computation, so we do permit
// uninitialized variables.
- arg.kind = XlaCompiler::Argument::kUninitializedVariable;
+ arg.initialized = false;
arg.type = DT_INVALID;
arg.shape = TensorShape();
}
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index 00a7358130..b3067be51d 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -335,7 +335,7 @@ class TensorArrayTest(xla_test.XLATestCase):
r0_bad = gen_data_flow_ops._tensor_array_read_v3(
handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
with self.assertRaisesOpError(
- "TensorArray dtype is float but Op requested dtype double."):
+ "TensorArray dtype is float but op has dtype double."):
r0_bad.eval()
# Test reading from a different index than the one we wrote to
diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
index 620fc84437..6ad72c6219 100644
--- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
@@ -51,13 +51,26 @@ class ArgOp : public XlaOpKernel {
XlaContext& xc = XlaContext::Get(ctx);
const XlaContext::Argument& arg = xc.args()[index_];
- if (arg.is_variable) {
+ if (arg.is_resource) {
+ XlaResource::Kind kind;
+ switch (arg.kind) {
+ case XlaCompiler::Argument::kVariable:
+ kind = XlaResource::kVariable;
+ break;
+ case XlaCompiler::Argument::kTensorArray:
+ kind = XlaResource::kTensorArray;
+ break;
+ default:
+ CHECK(false);
+ }
+
// TODO(phawkins): this code assumes that variables do not alias.
- XlaVariable* var;
- OP_REQUIRES_OK(ctx, xc.CreateVariable(index_, arg.name, arg.value.type,
- arg.value.handle, &var));
- var->tensor_array_size = arg.tensor_array_size;
- ctx->SetVariableOutput(0, var);
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx,
+ xc.CreateResource(kind, index_, arg.name, arg.value.type,
+ arg.value.handle, &resource));
+ resource->tensor_array_size = arg.tensor_array_size;
+ ctx->SetResourceOutput(0, resource);
} else if (arg.value.is_constant) {
ctx->SetConstantOutput(0, arg.value.constant_value);
} else {
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index c7510bf3d2..598b341002 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -41,32 +41,36 @@ namespace {
// Since the element shape is not always provided to the TensorArrayV3 operator,
// we must support lazily initialization of the TensorArray at the time of the
// first write.
-// If a TensorArray `var` has not been initialized, constructs storage for the
-// TensorArray with elements of `elem_shape`. For both initialized and
+// If a TensorArray `resource` has not been initialized, constructs storage for
+// the TensorArray with elements of `elem_shape`. For both initialized and
// uninitialized TensorArrays, checks that the tensor has a type compatible with
// 'dtype' and shape compatible with 'elem_shape'.
Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
- XlaVariable* var, DataType dtype,
+ XlaResource* resource, DataType dtype,
const TensorShape& elem_shape) {
- if (var->type != dtype) {
+ if (resource->kind != XlaResource::kTensorArray) {
+ return errors::InvalidArgument("Unexpected non-TensorArray resource");
+ }
+
+ if (resource->type != dtype) {
return errors::InvalidArgument(
- "TensorArray dtype is ", DataTypeString(var->type),
+ "TensorArray dtype is ", DataTypeString(resource->type),
" but op has dtype ", DataTypeString(dtype), ".");
}
- TF_RET_CHECK(var->tensor_array_size >= 0)
- << var->name << " size " << var->tensor_array_size;
+ TF_RET_CHECK(resource->tensor_array_size >= 0)
+ << resource->name << " size " << resource->tensor_array_size;
TensorShape ta_shape;
- ta_shape.AddDim(var->tensor_array_size);
+ ta_shape.AddDim(resource->tensor_array_size);
ta_shape.AppendShape(elem_shape);
- if (var->value.handle() == 0) {
+ if (resource->value.handle() == 0) {
// TensorArray has not been initialized.
- xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type);
- var->value = builder->Broadcast(zero, ta_shape.dim_sizes());
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type);
+ resource->value = builder->Broadcast(zero, ta_shape.dim_sizes());
} else {
// Checks the elem_shape matches the TensorArray shape.
- auto shape_or_status = builder->GetShape(var->value);
+ auto shape_or_status = builder->GetShape(resource->value);
if (!shape_or_status.ok()) {
return shape_or_status.status();
}
@@ -80,6 +84,44 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
return Status::OK();
}
+// Checks that the TensorArray 'resource' has been initialized, and has type
+// 'dtype'. Sets 'shape' to the shape
+Status CheckTensorArrayIsInitialized(const string& op_name,
+ const XlaResource* resource,
+ DataType dtype) {
+ if (resource->kind != XlaResource::kTensorArray) {
+ return errors::InvalidArgument(
+ "Unexpected non-TensorArray resource passed "
+ "to ",
+ op_name);
+ }
+ if (resource->value.handle() == 0) {
+ return errors::InvalidArgument("Uninitialized TensorArray passed to ",
+ op_name);
+ }
+ if (resource->type != dtype) {
+ return errors::InvalidArgument(
+ "TensorArray dtype is ", DataTypeString(resource->type),
+ " but op has dtype ", DataTypeString(dtype), ".");
+ }
+
+ return Status::OK();
+}
+
+Status GetTensorArrayShape(const XlaResource* resource,
+ xla::ComputationBuilder* builder,
+ TensorShape* shape) {
+ auto shape_or_status = builder->GetShape(resource->value);
+ if (!shape_or_status.ok()) {
+ return shape_or_status.status();
+ }
+ *shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie());
+ if (shape->dims() < 1) {
+ return errors::InvalidArgument("TensorArray rank must be >= 1");
+ }
+ return Status::OK();
+}
+
// Pads 'x' with 'count' zero indices. 'x' must have 1 element.
xla::ComputationDataHandle PadIndexWithZeros(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
@@ -125,7 +167,6 @@ class TensorArrayOp : public XlaOpKernel {
errors::InvalidArgument("TensorArray size must be >= 0"));
xla::ComputationBuilder* b = ctx->builder();
- b->set_die_immediately_on_error(true);
// Initializes the TensorArray value if we know the element shape.
// Otherwise, defer initialization to the first write.
@@ -141,12 +182,13 @@ class TensorArrayOp : public XlaOpKernel {
}
XlaContext& xc = XlaContext::Get(ctx);
- XlaVariable* var;
+ XlaResource* var;
string name = strings::StrCat("TensorArray: ", tensor_array_name_);
- OP_REQUIRES_OK(ctx,
- xc.CreateVariable(-1, std::move(name), dtype_, value, &var));
+ OP_REQUIRES_OK(
+ ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
+ dtype_, value, &var));
var->tensor_array_size = size;
- ctx->SetVariableOutput(0, var);
+ ctx->SetResourceOutput(0, var);
ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
}
@@ -173,11 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel {
// Initializes the TensorArray, if the element shape was not known at
// construction time.
- XlaVariable* var;
- OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+ OP_REQUIRES_OK(ctx,
+ MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
- xla::ComputationDataHandle ta = var->value;
+ xla::ComputationDataHandle ta = resource->value;
xla::ComputationDataHandle index = ctx->Input(1);
xla::ComputationDataHandle value = ctx->Input(2);
@@ -191,7 +234,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
xla::ComputationDataHandle written =
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written));
+ resource->value = written;
ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
}
@@ -210,20 +253,17 @@ class TensorArrayReadOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- DataType ta_type;
- TensorShape ta_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
- OP_REQUIRES(ctx, ta_type == dtype_,
- errors::InvalidArgument(
- "TensorArray dtype is ", DataTypeString(ta_type),
- " but Op requested dtype ", DataTypeString(dtype_), "."));
- OP_REQUIRES(ctx, ta_shape.dims() >= 1,
- errors::InvalidArgument("TensorArray rank must be >= 1"));
-
xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle ta;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+
+ OP_REQUIRES_OK(ctx,
+ CheckTensorArrayIsInitialized(name(), resource, dtype_));
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
+
+ xla::ComputationDataHandle ta = resource->value;
xla::ComputationDataHandle index = ctx->Input(1);
// start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
@@ -255,13 +295,15 @@ class TensorArrayGatherOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- DataType ta_type;
+ xla::ComputationBuilder* b = ctx->builder();
+
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+
+ OP_REQUIRES_OK(ctx,
+ CheckTensorArrayIsInitialized(name(), resource, dtype_));
TensorShape ta_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
- OP_REQUIRES(ctx, ta_type == dtype_,
- errors::InvalidArgument("TensorArray type mismatch"));
- OP_REQUIRES(ctx, ta_shape.dims() >= 1,
- errors::InvalidArgument("TensorArray rank must be >= 1"));
+ OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
const TensorShape indices_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, indices_shape.dims() >= 1,
@@ -269,10 +311,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
const int num_indices = indices_shape.dim_size(0);
auto indices = ctx->Input(1);
- xla::ComputationBuilder* b = ctx->builder();
-
- xla::ComputationDataHandle ta;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+ xla::ComputationDataHandle ta = resource->value;
// For each index in `indices`, add the corresponding slice to `slices`.
std::vector<xla::ComputationDataHandle> slices(num_indices);
@@ -320,11 +359,12 @@ class TensorArrayScatterOp : public XlaOpKernel {
const TensorShape value_shape = ctx->InputShape(2);
- XlaVariable* var;
- OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
TensorShape elem_shape = value_shape;
elem_shape.RemoveDim(0);
- OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+ OP_REQUIRES_OK(ctx,
+ MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
const TensorShape indices_shape = ctx->InputShape(1);
OP_REQUIRES(ctx, indices_shape.dims() >= 1,
@@ -332,7 +372,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
const int num_indices = indices_shape.dim_size(0);
const xla::ComputationDataHandle indices = ctx->Input(1);
- xla::ComputationDataHandle ta = var->value;
+ xla::ComputationDataHandle ta = resource->value;
const xla::ComputationDataHandle value = ctx->Input(2);
auto slice_dims = value_shape.dim_sizes();
@@ -355,7 +395,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
}
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
+ resource->value = ta;
ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
}
@@ -374,18 +414,17 @@ class TensorArrayConcatOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
- DataType ta_type;
- TensorShape ta_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
- OP_REQUIRES(ctx, ta_type == dtype_,
- errors::InvalidArgument("TensorArray type mismatch"));
- OP_REQUIRES(ctx, ta_shape.dims() >= 1,
- errors::InvalidArgument("TensorArray rank must be >= 1"));
-
xla::ComputationBuilder* b = ctx->builder();
- xla::ComputationDataHandle ta;
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+
+ OP_REQUIRES_OK(ctx,
+ CheckTensorArrayIsInitialized(name(), resource, dtype_));
+ TensorShape ta_shape;
+ OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
+
+ xla::ComputationDataHandle ta = resource->value;
auto ta_dims = ta_shape.dim_sizes();
std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
@@ -436,19 +475,20 @@ class TensorArraySplitOp : public XlaOpKernel {
elem_shape.set_dim(0, length);
xla::ComputationBuilder* b = ctx->builder();
- XlaVariable* var;
- OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
- OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
- xla::ComputationDataHandle ta = var->value;
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+ OP_REQUIRES_OK(ctx,
+ MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
+ xla::ComputationDataHandle ta = resource->value;
TensorShape ta_shape;
- ta_shape.AddDim(var->tensor_array_size);
+ ta_shape.AddDim(resource->tensor_array_size);
ta_shape.AppendShape(elem_shape);
- OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size,
+ OP_REQUIRES(ctx, lengths.size() == resource->tensor_array_size,
errors::InvalidArgument(
"TensorArray's size is not equal to the size of lengths (",
- lengths.size(), " vs. ", var->tensor_array_size, ")"));
+ lengths.size(), " vs. ", resource->tensor_array_size, ")"));
const xla::ComputationDataHandle value = ctx->Input(1);
@@ -457,8 +497,7 @@ class TensorArraySplitOp : public XlaOpKernel {
value_shape.DebugString(), " vs. ",
ta_shape.DebugString()));
- ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
+ resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes()));
ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
}
@@ -476,8 +515,8 @@ class TensorArraySizeOp : public XlaOpKernel {
explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
- XlaVariable* var;
- OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ XlaResource* var;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var));
Tensor size_tensor(DT_INT32, {});
size_tensor.scalar<int32>()() = static_cast<int32>(var->tensor_array_size);
ctx->SetConstantOutput(0, size_tensor);
@@ -498,31 +537,31 @@ class TensorArrayGradOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
- XlaVariable* var;
- OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
- DataType ta_type;
+ OP_REQUIRES_OK(
+ ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type));
TensorShape ta_shape;
- OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
- OP_REQUIRES(ctx, ta_shape.dims() >= 1,
- errors::InvalidArgument("TensorArray rank must be >= 1"));
+ OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
// Finds or looks up the corresponding gradient TensorArray, which stores
// gradients computed during backpropagation.
- XlaVariable*& gradient = var->tensor_array_gradient[source_];
+ XlaResource*& gradient = resource->tensor_array_gradient[source_];
if (!gradient) {
- xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type);
+ xla::ComputationDataHandle zero = XlaHelpers::Zero(b, resource->type);
xla::ComputationDataHandle value =
b->Broadcast(zero, ta_shape.dim_sizes());
XlaContext& xc = XlaContext::Get(ctx);
- string name = strings::StrCat("TensorArrayGrad: ", var->name);
- OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type,
- value, &gradient));
- gradient->tensor_array_size = var->tensor_array_size;
+ string name = strings::StrCat("TensorArrayGrad: ", resource->name);
+ OP_REQUIRES_OK(
+ ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
+ resource->type, value, &gradient));
+ gradient->tensor_array_size = resource->tensor_array_size;
}
- ctx->SetVariableOutput(0, gradient);
+ ctx->SetResourceOutput(0, gradient);
ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
}
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index ae9a358e22..0caa9c5f37 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -42,29 +42,38 @@ Status MakeXlaCompilerArgumentsFromInputs(
<< " shape: " << ctx->InputShape(i).DebugString();
XlaCompiler::Argument& arg = (*args)[i];
DataType type = ctx->input_type(i);
- // When reading a variable input, use the type and shape of the variable's
+ // When reading a resource input, use the type and shape of the resource's
// current value.
if (type == DT_RESOURCE) {
- XlaVariable* var;
- TF_RETURN_IF_ERROR(ctx->GetVariableInput(i, &var));
-
- bool initialized = var->value.handle() > 0;
- if (initialized) {
- arg.kind = XlaCompiler::Argument::kVariable;
- TF_RETURN_IF_ERROR(
- ctx->GetVariableTypeAndShape(i, &arg.type, &arg.shape));
+ XlaResource* resource;
+ TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource));
+
+ arg.initialized = resource->value.handle() > 0;
+ switch (resource->kind) {
+ case XlaResource::kVariable:
+ arg.kind = XlaCompiler::Argument::kVariable;
+ break;
+ case XlaResource::kTensorArray:
+ arg.kind = XlaCompiler::Argument::kTensorArray;
+ break;
+ case XlaResource::kInvalid:
+ CHECK(false);
+ }
+ arg.type = resource->type;
+ if (arg.initialized) {
+ auto shape = ctx->builder()->GetShape(resource->value);
+ TF_RETURN_IF_ERROR(shape.status());
+ arg.shape = XLAShapeToTensorShape(*shape.ValueOrDie());
} else {
- arg.kind = XlaCompiler::Argument::kUninitializedVariable;
- arg.type = var->type;
*has_uninitialized_vars = true;
}
- arg.tensor_array_size = var->tensor_array_size;
- arg.name = var->name;
+ arg.tensor_array_size = resource->tensor_array_size;
+ arg.name = resource->name;
// TODO(phawkins): propagate TensorArray gradients into loops.
- VLOG(2) << " variable " << var->name
+ VLOG(2) << " resource " << resource->name
<< " type: " << DataTypeString(arg.type)
<< " shape: " << arg.shape.DebugString()
- << " initialized: " << initialized;
+ << " initialized: " << arg.initialized;
} else {
arg.kind = XlaCompiler::Argument::kParameter;
@@ -86,7 +95,7 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
}
void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
- VLOG(1) << "WhileOp::Compute";
+ VLOG(1) << "WhileOp::Compile";
std::vector<XlaCompiler::Argument> arguments;
bool has_uninitialized_vars;
@@ -100,16 +109,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
VLOG(1) << "Compiling body";
- // All resource variables that are inputs to the loop's body must also be
+ // All resource that are inputs to the loop's body must also be
// present as loop body outputs; the signature of the loop's input and
// output must match. We ensure this by asking the compiler to include the
- // current values of all variables, even if they haven't been updated by the
+ // current values of all resources, even if they haven't been updated by the
// computation.
// TODO(phawkins): consider adding loop-invariant inputs to XLA's While()
// operator.
XlaCompiler::CompileOptions body_options;
body_options.use_tuple_arg = use_tuple_arg;
- body_options.return_updated_values_for_all_variables = true;
+ body_options.return_updated_values_for_all_resources = true;
XlaCompiler::CompilationResult body;
OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
arguments, &body));
@@ -118,26 +127,28 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
// we may not know the shape of a TensorArray if it is first written inside
// the loop. Ideally we would require the user to provide a static shape,
// but this is not always easy.
- // So if uninitialized variables are used by the loop body, we compile the
+ // So if uninitialized resource are used by the loop body, we compile the
// body function twice:
- // 1) once with uninitialized variable inputs. We discard the computation
- // but we assume variable shapes reach a fixpoint after one iteration.
- // So we can use the output shapes of the variables as the "true" shapes.
+ // 1) once with uninitialized resource inputs. We discard the computation
+ // but we assume resource shapes reach a fixpoint after one iteration.
+ // So we can use the output shapes of the resource as the "true" shapes.
// 2) again with the "correct" input shapes determined by (1).
if (has_uninitialized_vars) {
- // Initializes any uninitialized variables with zero values of the
+ // Initializes any uninitialized resource with zero values of the
// shape determined by the first compilation.
- for (int i = 0; i < body.variable_updates.size(); ++i) {
- const XlaCompiler::VariableUpdate& update = body.variable_updates[i];
+ for (int i = 0; i < body.resource_updates.size(); ++i) {
+ const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
XlaCompiler::Argument& arg = arguments[update.input_index];
- if (arg.kind == XlaCompiler::Argument::kUninitializedVariable) {
- arg.kind = XlaCompiler::Argument::kVariable;
+ if (!arg.initialized) {
+ arg.initialized = true;
arg.shape = update.shape;
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx,
+ ctx->GetResourceInput(update.input_index, &resource));
+
xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, arg.type);
- auto value = builder->Broadcast(zero, update.shape.dim_sizes());
- OP_REQUIRES_OK(
- ctx, ctx->AssignVariable(update.input_index, arg.type, value));
+ resource->value = builder->Broadcast(zero, update.shape.dim_sizes());
}
}
// Recompile the body with the "correct" shapes.
@@ -191,7 +202,9 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
for (int i = 0; i < num_inputs; ++i) {
int input_num = body.input_mapping[i];
if (ctx->input_type(input_num) == DT_RESOURCE) {
- OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(input_num, &inputs[i]));
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
+ inputs[i] = resource->value;
} else {
inputs[i] = ctx->Input(i);
}
@@ -225,16 +238,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
}
// Updates the values of any resource variables modified by the loop.
- for (int i = 0; i < body.variable_updates.size(); ++i) {
- const XlaCompiler::VariableUpdate& update = body.variable_updates[i];
+ for (int i = 0; i < body.resource_updates.size(); ++i) {
+ const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
if (update.modified) {
int pos = body.outputs.size() + i;
- OP_REQUIRES_OK(ctx, ctx->AssignVariable(update.input_index, update.type,
- get_loop_output(pos)));
+ resource->value = get_loop_output(pos);
}
VLOG(2) << "Loop-carried variable: pos: " << update.input_index
- << " name: " << ctx->VariableDebugString(update.input_index)
- << " modified: " << update.modified
+ << " name: " << resource->name << " modified: " << update.modified
<< " type: " << DataTypeString(update.type)
<< " shape: " << update.shape.DebugString();
// Copies the identity of the resource variable from input to output
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h
index 75630bee39..e4f43f1950 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.h
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h
@@ -64,26 +64,35 @@ class XlaCompilationDevice : public LocalDevice {
std::unique_ptr<XlaCompilationAllocator> allocator_;
};
-struct XlaVariable {
- // If this variable is visible externally, what was its argument number?
+// Represents a resource, such as a Variable or TensorArray.
+struct XlaResource {
+ enum Kind {
+ kInvalid,
+ kVariable,
+ kTensorArray,
+ };
+
+ Kind kind = kInvalid;
+
+ // If this resource is visible externally, what was its argument number?
int arg_num = -1;
- // A descriptive name for the variable, used in error messages.
+ // A descriptive name for the resource, used in error messages.
string name;
- // Current type and value of the variable. Uninitialized variables are
+ // Current type and value of the resource. Uninitialized resources are
// represented by a default (zero) handle and type DT_INVALID.
- // While the type of a variable is notionally fixed during execution, when
- // a variable is first initialized we do not yet know its type, so we keep
+ // While the type of a resource is notionally fixed during execution, when
+ // a resource is first initialized we do not yet know its type, so we keep
// track of its type dynamically.
DataType type = DT_INVALID;
xla::ComputationDataHandle value;
- // Value of the variable at computation entry. Used to detect which
+ // Value of the resource at computation entry. Used to detect which
// variables have new values that need to be written back.
xla::ComputationDataHandle initial_value;
- // We treat TensorArrays as a Variable with some extra metadata.
+ // TensorArray-specific fields
// 'tensor_array_size' stores the expected size of the TensorArray. We need
// to store this since sometimes TensorArrays must be initialized lazily since
@@ -91,10 +100,10 @@ struct XlaVariable {
int64 tensor_array_size = -1;
// 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes
- // to an XlaVariable containing the gradient TensorArrays. We store a pointer
+ // to an XlaResource containing the gradient TensorArrays. We store a pointer
// here since there should only be one gradient TensorArray per 'source'
// string, irrespective of the number of calls to TensorArrayGrad.
- std::unordered_map<string, XlaVariable*> tensor_array_gradient;
+ std::unordered_map<string, XlaResource*> tensor_array_gradient;
};
// A XlaExpression wraps an XLA computation. Each Tensor on an
@@ -115,8 +124,8 @@ class XlaExpression {
bool has_constant_value() const { return has_constant_value_; }
const Tensor& constant_value() const { return constant_value_; }
- void set_variable(XlaVariable* variable) { variable_ = variable; }
- XlaVariable* variable() const { return variable_; }
+ void set_resource(XlaResource* resource) { resource_ = resource; }
+ XlaResource* resource() const { return resource_; }
private:
// The XLA handle of the expression's computation.
@@ -128,7 +137,7 @@ class XlaExpression {
bool has_constant_value_ = false;
Tensor constant_value_;
- XlaVariable* variable_ = nullptr; // Not owned.
+ XlaResource* resource_ = nullptr; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
};
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index d7c92a710c..50b384997a 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -251,35 +251,36 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
std::vector<xla::Shape>* input_shapes) {
context_args->resize(args.size());
- // Argument numbers of arguments and variables that are to be passed to the
+ // Argument numbers of arguments and resources that are to be passed to the
// XLA computation as runtime parameters.
- std::vector<int> parameters, variables;
+ std::vector<int> parameters, resources;
parameters.reserve(args.size());
- variables.reserve(args.size());
+ resources.reserve(args.size());
for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
++i) {
XlaContext::Argument& context_arg = (*context_args)[i];
+ context_arg.kind = args[i].kind;
context_arg.name = args[i].name;
context_arg.value.constant_value = args[i].constant_value;
context_arg.value.type = args[i].type;
switch (args[i].kind) {
case XlaCompiler::Argument::kVariable:
- variables.push_back(i);
- context_arg.is_variable = true;
- context_arg.value.is_constant = false;
+ case XlaCompiler::Argument::kTensorArray:
+ context_arg.is_resource = true;
+ if (args[i].initialized) {
+ resources.push_back(i);
+ context_arg.value.is_constant = false;
+ } else {
+ context_arg.value.is_constant = true;
+ }
context_arg.tensor_array_size = args[i].tensor_array_size;
break;
case XlaCompiler::Argument::kParameter:
parameters.push_back(i);
context_arg.value.is_constant = false;
break;
- case XlaCompiler::Argument::kUninitializedVariable:
- context_arg.is_variable = true;
- context_arg.value.is_constant = true;
- context_arg.tensor_array_size = args[i].tensor_array_size;
- break;
case XlaCompiler::Argument::kConstant:
context_arg.value.is_constant = true;
break;
@@ -290,7 +291,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
// Append parameters containing variable values after the other runtime
// parameters.
- parameters.insert(parameters.end(), variables.begin(), variables.end());
+ parameters.insert(parameters.end(), resources.begin(), resources.end());
if (parameters.empty()) {
return Status::OK();
}
@@ -331,22 +332,22 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
// variable states, generated by the symbolic evaluation.
// If `has_side_effects` is true, the computation has side effects and should be
// built even if it has no outputs.
-// If `return_updated_values_for_all_variables` is true, all variables will be
-// included in `variable_updates`, regardless of whether their value changed.
+// If `return_updated_values_for_all_resources` is true, all resources will be
+// included in `resource_updates`, regardless of whether their value changed.
// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
-// Sets `*variable_updates` to a description of variables whose values are
+// Sets `*resource_updates` to a description of resources whose values are
// written by the computation; the variable writes are the last
-// `variable_updates.size()` return values from the computation. Each entry in
-// `variable_updates` is a (input_index, type) pair, where `input_index` is the
+// `resource_updates.size()` return values from the computation. Each entry in
+// `resource_updates` is a (input_index, type) pair, where `input_index` is the
// index of a resource variable argument to the computation, and `type` is the
// type of the final output.
Status BuildComputation(
const std::vector<XlaContext::HandleOrConstant>& retvals,
- const std::vector<std::unique_ptr<XlaVariable>>& variables,
- bool has_side_effects, bool return_updated_values_for_all_variables,
+ const std::vector<std::unique_ptr<XlaResource>>& resources,
+ bool has_side_effects, bool return_updated_values_for_all_resources,
xla::ComputationBuilder* builder, xla::Computation* computation,
int* num_nonconst_outputs,
- std::vector<XlaCompiler::VariableUpdate>* variable_updates) {
+ std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
std::vector<xla::ComputationDataHandle> elems;
elems.reserve(retvals.size());
for (const XlaContext::HandleOrConstant& retval : retvals) {
@@ -356,24 +357,24 @@ Status BuildComputation(
}
*num_nonconst_outputs = elems.size();
- // Add return values for variables whose values have changed.
- std::vector<const XlaVariable*> arg_vars;
- arg_vars.reserve(variables.size());
- for (const auto& var : variables) {
+ // Add return values for resources whose values have changed.
+ std::vector<const XlaResource*> arg_vars;
+ arg_vars.reserve(resources.size());
+ for (const auto& var : resources) {
if (var->arg_num >= 0) {
arg_vars.push_back(var.get());
}
}
std::sort(arg_vars.begin(), arg_vars.end(),
- [](const XlaVariable* a, const XlaVariable* b) {
+ [](const XlaResource* a, const XlaResource* b) {
return a->arg_num < b->arg_num;
});
- for (const XlaVariable* var : arg_vars) {
+ for (const XlaResource* var : arg_vars) {
bool modified = var->value.handle() != var->initial_value.handle();
- if (return_updated_values_for_all_variables || modified) {
- variable_updates->emplace_back();
- XlaCompiler::VariableUpdate& update = variable_updates->back();
+ if (return_updated_values_for_all_resources || modified) {
+ resource_updates->emplace_back();
+ XlaCompiler::ResourceUpdate& update = resource_updates->back();
update.input_index = var->arg_num;
update.type = var->type;
update.modified = modified;
@@ -439,10 +440,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
int num_nonconst_outputs;
result->computation = std::make_shared<xla::Computation>();
TF_RETURN_IF_ERROR(BuildComputation(
- context->retvals(), context->variables(), context->has_side_effects(),
- options.return_updated_values_for_all_variables, &builder,
+ context->retvals(), context->resources(), context->has_side_effects(),
+ options.return_updated_values_for_all_resources, &builder,
result->computation.get(), &num_nonconst_outputs,
- &result->variable_updates));
+ &result->resource_updates));
result->requires_runtime_context = context->has_context_parameter();
@@ -517,15 +518,15 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
}
}
- for (std::vector<VariableUpdate>::size_type i = 0;
- i < result->variable_updates.size(); ++i) {
+ for (std::vector<ResourceUpdate>::size_type i = 0;
+ i < result->resource_updates.size(); ++i) {
if (num_computation_outputs > 1) {
- result->variable_updates[i].shape =
+ result->resource_updates[i].shape =
XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(
result->xla_output_shape, computation_output));
} else {
CHECK_EQ(0, computation_output);
- result->variable_updates[i].shape =
+ result->resource_updates[i].shape =
XLAShapeToTensorShape(result->xla_output_shape);
}
++computation_output;
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 6b9bf4159d..58e42c3474 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -85,14 +85,14 @@ class XlaCompiler {
// Argument is a compile-time constant. No associated runtime parameter.
kConstant,
- // Argument is a variable that has not been initialized yet. No associated
- // runtime parameter.
- kUninitializedVariable,
-
- // Argument is a variable that already has a value set. Expects a runtime
- // parameter containing the current value.
+ // Argument is a variable resource. Has an associated runtime parameter
+ // iff `initialized` is true.
kVariable,
+ // Argument is a TensorArray resource. Has an associated runtime parameter
+ // iff `initialized` is true.
+ kTensorArray,
+
// Argument is a run-time parameter.
kParameter,
};
@@ -114,8 +114,11 @@ class XlaCompiler {
// The name of this argument, used for debugging.
string name;
- // For a kVariable or kUninitializedVariable corresponding to a TensorArray,
- // what is the tensor array's declared size?
+ // For a kVariable or kTensorArray, has this resource been initialized?
+ bool initialized = false;
+
+ // For a kTensorArray, what is the array's declared size? (Used for lazy
+ // initialization.)
int64 tensor_array_size = -1;
bool operator==(const Argument& other) const;
@@ -133,7 +136,7 @@ class XlaCompiler {
};
// Describes a variable write side effect of the computation.
- struct VariableUpdate {
+ struct ResourceUpdate {
// Index of the input that contains the variable resource to write to.
int input_index;
@@ -142,14 +145,14 @@ class XlaCompiler {
TensorShape shape;
// Was the value of the variable modified by the computation?
- // (Always true, unless `return_updated_values_for_all_variables` is true.)
+ // (Always true, unless `return_updated_values_for_all_resources` is true.)
bool modified;
};
struct CompilationResult {
// Vector that maps from the parameters of the XLA computation to their
// original argument positions. To handle compile-time constant inputs and
- // variables, the parameters to the XLA computation may be a subset of the
+ // resources, the parameters to the XLA computation may be a subset of the
// original arguments, and are not necessarily in the same order.)
std::vector<int> input_mapping;
@@ -172,10 +175,10 @@ class XlaCompiler {
// containing both constant and non-constant results.
std::vector<OutputDescription> outputs;
- // Variables whose values were updated by the computation, ordered
- // by return value position. Variable updates follow the non-constant
+ // Resources whose values were updated by the computation, ordered
+ // by return value position. Resource updates follow the non-constant
// results in the outputs of XLA computation.
- std::vector<VariableUpdate> variable_updates;
+ std::vector<ResourceUpdate> resource_updates;
// The XLA computation built from the tensorflow subgraph. May be null
// if the output consists solely of compile-time constants.
@@ -229,12 +232,12 @@ class XlaCompiler {
// arguments; if false, each argument gets its own parameter.
bool use_tuple_arg = false;
- // If 'return_updated_values_for_all_variables' is true, then updated
- // values of all resource variables arguments will be included in the
- // 'variable_updates' of the computation, even if the variable was not
+ // If 'return_updated_values_for_all_resources' is true, then updated
+ // values of all resource resources arguments will be included in the
+ // 'resource_updates' of the computation, even if the resource was not
// modified by the computation. Used when compiling loop bodies to ensure
// the input and output signatures match.
- bool return_updated_values_for_all_variables = false;
+ bool return_updated_values_for_all_resources = false;
};
// Compiles a Tensorflow function `fn_name_attrs` into an XLA computation.
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 4440b53069..1a37d61944 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -129,16 +129,18 @@ void XlaContext::AddSideEffects() {
xla::ComputationBuilder* XlaContext::builder() { return builder_; }
-Status XlaContext::CreateVariable(int arg_num, string name, DataType type,
+Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num,
+ string name, DataType type,
const xla::ComputationDataHandle& handle,
- XlaVariable** variable) {
- variables_.emplace_back(new XlaVariable);
- *variable = variables_.back().get();
- XlaVariable& var = **variable;
- var.arg_num = arg_num;
- var.name = std::move(name);
- var.type = type;
- var.initial_value = var.value = handle;
+ XlaResource** resource) {
+ resources_.emplace_back(new XlaResource);
+ *resource = resources_.back().get();
+ XlaResource& r = **resource;
+ r.kind = kind;
+ r.arg_num = arg_num;
+ r.name = std::move(name);
+ r.type = type;
+ r.initial_value = r.value = handle;
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 3978baaf63..dbede52b5d 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -52,11 +52,13 @@ class XlaContext : public ResourceBase {
};
struct Argument {
- // Descriptive name for the variable, for use in error messages.
+ XlaCompiler::Argument::Kind kind;
+
+ // Descriptive name for the resource, for use in error messages.
string name;
- // Is this a variable?
- bool is_variable = false;
+ // Is this a resource?
+ bool is_resource = false;
HandleOrConstant value;
@@ -106,15 +108,15 @@ class XlaContext : public ResourceBase {
bool has_side_effects() const { return has_side_effects_; }
- // Creates a variable with variable `variable_id` and initial type `type` and
+ // Creates a resource with resource `kind` and initial type `type` and
// value `handle`. `name` is a descriptive name for use in error messages.
- // Fails if the variable already exists.
- Status CreateVariable(int arg_num, string name, DataType type,
- const xla::ComputationDataHandle& handle,
- XlaVariable** variable);
+ // Fails if the resource already exists.
+ Status CreateResource(XlaResource::Kind kind, int arg_num, string name,
+ DataType type, const xla::ComputationDataHandle& handle,
+ XlaResource** resource);
- const std::vector<std::unique_ptr<XlaVariable>>& variables() {
- return variables_;
+ const std::vector<std::unique_ptr<XlaResource>>& resources() {
+ return resources_;
}
// Get an XLA lambda to compute Max. This is cached in the
@@ -166,8 +168,8 @@ class XlaContext : public ResourceBase {
// Does the computation have side effects, i.e., Send() calls?
bool has_side_effects_ = false;
- // Holds ownership of variables. The variables are not ordered.
- std::vector<std::unique_ptr<XlaVariable>> variables_;
+ // Holds ownership of resources. The resources are not ordered.
+ std::vector<std::unique_ptr<XlaResource>> resources_;
// Cache of prebuilt computations indexed by their type.
using ComputationMap = std::map<DataType, xla::Computation>;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 3272b1efa1..edb7e2a563 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -39,7 +39,7 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
const XlaExpression* expression =
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
CHECK(expression->handle().handle() != 0 ||
- expression->variable() != nullptr);
+ expression->resource() != nullptr);
VLOG(1) << "Fetched T" << expression->handle().handle();
return expression;
}
@@ -252,8 +252,9 @@ Status XlaOpKernelContext::ReadVariableInput(
int index, xla::ComputationDataHandle* value) {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
- XlaVariable* variable = expression->variable();
+ XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
+ TF_RET_CHECK(variable->kind == XlaResource::kVariable);
if (variable->value.handle() == 0) {
return errors::InvalidArgument("Read of uninitialized variable ",
variable->name);
@@ -262,22 +263,13 @@ Status XlaOpKernelContext::ReadVariableInput(
return Status::OK();
}
-string XlaOpKernelContext::VariableDebugString(int index) {
- const Tensor& tensor = context_->input(index);
- const XlaExpression* expression = CastExpressionFromTensor(tensor);
- XlaVariable* variable = expression->variable();
- if (!variable) {
- return "<invalid variable ID>";
- }
- return variable->name;
-}
-
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
TensorShape* shape) const {
const Tensor& tensor = context_->input(index);
const XlaExpression* expression = CastExpressionFromTensor(tensor);
- XlaVariable* variable = expression->variable();
+ XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
+ TF_RET_CHECK(variable->kind == XlaResource::kVariable);
if (variable->value.handle() == 0) {
return errors::InvalidArgument("Read of uninitialized variable ",
variable->name);
@@ -337,33 +329,34 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
expression->set_constant_value(constant);
}
-void XlaOpKernelContext::SetVariableOutput(int index, XlaVariable* variable) {
+void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
Tensor* output = nullptr;
- // The shape of the output tensor is the shape of the variable resource
- // (i.e., a scalar), not the shape of the variable's value.
+ // The shape of the output tensor is the shape of the resource itself
+ // (i.e., a scalar), not the shape of the resource's value.
OP_REQUIRES_OK(context_,
context_->allocate_output(index, TensorShape(), &output));
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
- expression->set_variable(variable);
+ expression->set_resource(resource);
}
-Status XlaOpKernelContext::GetVariableInput(int index, XlaVariable** variable) {
+Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
const XlaExpression* expression =
CastExpressionFromTensor(context_->input(index));
- TF_RET_CHECK(expression->variable() != nullptr);
- *variable = expression->variable();
+ TF_RET_CHECK(expression->resource() != nullptr);
+ *resource = expression->resource();
return Status::OK();
}
Status XlaOpKernelContext::AssignVariable(
- int index, DataType type, const xla::ComputationDataHandle& handle) {
+ int input_index, DataType type, const xla::ComputationDataHandle& handle) {
TF_RET_CHECK(handle.handle() != 0);
SetOpHasSideEffects();
const XlaExpression* expression =
- CastExpressionFromTensor(context_->input(index));
- XlaVariable* variable = expression->variable();
+ CastExpressionFromTensor(context_->input(input_index));
+ XlaResource* variable = expression->resource();
TF_RET_CHECK(variable != nullptr);
+ TF_RET_CHECK(variable->kind == XlaResource::kVariable);
if (!((variable->type == DT_INVALID && type != DT_INVALID) ||
(variable->type == type))) {
return errors::InvalidArgument(
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index a25774c3a6..b151286217 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -148,6 +148,12 @@ class XlaOpKernelContext {
// Variables
+ // Sets '*resource' to the resource associated with input `index`.
+ Status GetResourceInput(int index, XlaResource** resource);
+
+ // Sets output 'index' to be a reference to resource 'resource'.
+ void SetResourceOutput(int index, XlaResource* resource);
+
// Sets `*type` and `*shape` to the current type and shape of a variable's
// value.
Status GetVariableTypeAndShape(int index, DataType* type,
@@ -158,20 +164,10 @@ class XlaOpKernelContext {
Status ReadVariableInput(int index, xla::ComputationDataHandle* value);
// Assigns the value `handle` to the variable referenced by input
- // `variable_index`. Marks the operator as having side effects.
- Status AssignVariable(int variable_index, DataType type,
+ // `input_index`. Marks the operator as having side effects.
+ Status AssignVariable(int input_index, DataType type,
const xla::ComputationDataHandle& handle);
- // Sets '*variable' to the variable associated with input `index`.
- Status GetVariableInput(int index, XlaVariable** variable);
-
- // Sets output 'index' to be a reference to variable 'variable'. Used
- // to propagate resource variables through the compilation.
- void SetVariableOutput(int index, XlaVariable* variable);
-
- // Returns a human-readable debug string describing 'variable_index'.
- string VariableDebugString(int variable_index);
-
// Helper routines for the OP_REQUIRES macros
void CtxFailure(Status s);
void CtxFailureWithWarning(Status s);