author     Peter Hawkins <phawkins@google.com>              2017-06-16 10:32:29 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-06-16 10:36:33 -0700
commit     a66de1eca225bc95e7972974a7089d84df8a8055
tree       7a332e8fd1f1656a0ef0caacbfff0851ab13967f
parent     d06121593035687176d8b660b83bab568853deff
[TF:XLA] Refactor handling of Resources (Variables and TensorArrays) in the XLA bridge.
* Rename "Variable" to "Resource" in many places where non-Variable resources might be used.
* Add kTensorArray to the XlaCompiler::Argument enum. Remove kUninitializedVariable and make "initialized" a separate boolean field.
* Add a kind field to XlaResource. Add checks that Variables are not used where TensorArrays are expected, and vice-versa.
* Cleanups to the TensorArray operators.
PiperOrigin-RevId: 159244478
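
The core API change in miniature: an uninitialized variable is no longer a distinct Argument kind; the kind now says what the resource is, and a separate boolean says whether it carries a value. A short sketch of the before/after encoding, with field values taken from the xla_compilation_cache.cc hunk below (the surrounding setup is illustrative only, not code from this CL):

// Old encoding: initialization state was baked into the kind enum.
XlaCompiler::Argument old_arg;
old_arg.kind = XlaCompiler::Argument::kUninitializedVariable;  // removed by this CL

// New encoding: kind identifies the resource; `initialized` is orthogonal.
XlaCompiler::Argument arg;
arg.kind = XlaCompiler::Argument::kVariable;  // or kTensorArray
arg.initialized = false;    // uninitialized: no runtime parameter is passed
arg.type = DT_INVALID;      // type/shape are unknown until first assignment
arg.shape = TensorShape();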
-rw-r--r--  tensorflow/compiler/jit/kernels/xla_device_launch_op.cc |   4
-rw-r--r--  tensorflow/compiler/jit/xla_compilation_cache.cc        |   5
-rw-r--r--  tensorflow/compiler/tests/tensor_array_ops_test.py      |   2
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/arg_op.cc            |  25
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc  | 205
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/while_op.cc          |  89
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compilation_device.h     |  35
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.cc              |  73
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiler.h               |  39
-rw-r--r--  tensorflow/compiler/tf2xla/xla_context.cc               |  20
-rw-r--r--  tensorflow/compiler/tf2xla/xla_context.h                |  26
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc             |  39
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.h              |  20
13 files changed, 327 insertions, 255 deletions
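
For orientation before the per-file hunks: op kernels now fetch an XlaResource via GetResourceInput() and check its kind explicitly, and TensorArray kernels write the resource's value directly instead of calling AssignVariable(). A condensed, hypothetical kernel showing the pattern (assembled from the APIs in the hunks below; it is not one of the ops changed by this CL):

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"

class ExampleTensorArrayOp : public XlaOpKernel {
 public:
  explicit ExampleTensorArrayOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    XlaResource* resource;
    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
    // The new kind field catches a Variable handle passed where a TensorArray
    // is expected (and vice versa), instead of misreading its contents.
    OP_REQUIRES(ctx, resource->kind == XlaResource::kTensorArray,
                errors::InvalidArgument("Expected a TensorArray resource"));
    // Writes update the resource directly; BuildComputation() later emits a
    // ResourceUpdate for each resource whose value handle changed.
    resource->value = ctx->Input(1);
    ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
  }
};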
diff --git a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc
index 29c5ff7242..e786fdb16d 100644
--- a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc
@@ -202,8 +202,8 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) {
 
   // Apply variable updates, if any.
   VLOG(2) << "Applying variable updates";
-  for (int i = 0; i < kernel->variable_updates.size(); ++i) {
-    const XlaCompiler::VariableUpdate& write = kernel->variable_updates[i];
+  for (int i = 0; i < kernel->resource_updates.size(); ++i) {
+    const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
     OP_REQUIRES(ctx, write.input_index >= 0 &&
                          write.input_index < ctx->num_inputs(),
                 errors::Internal("Invalid input index for variable write."));
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 63ca77f9a9..2325217b97 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -182,17 +182,18 @@ Status BuildArguments(int num_constant_args,
     XlaCompiler::Argument& arg = (*args)[input_num];
     arg.name = variable_args[variable_id].name;
+    arg.kind = XlaCompiler::Argument::kVariable;
     if (variable_args[variable_id].present) {
       const Tensor& value = variable_args[variable_id].value;
-      arg.kind = XlaCompiler::Argument::kVariable;
       arg.type = value.dtype();
       arg.shape = value.shape();
+      arg.initialized = true;
     } else {
       // The values of uninitialized variables are not passed as inputs, since
       // they are meaningless. However, it is legal to assign to a resource
       // variable for the first time inside the XLA computation, so we do permit
       // uninitialized variables.
-      arg.kind = XlaCompiler::Argument::kUninitializedVariable;
+      arg.initialized = false;
       arg.type = DT_INVALID;
       arg.shape = TensorShape();
     }
diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py
index 00a7358130..b3067be51d 100644
--- a/tensorflow/compiler/tests/tensor_array_ops_test.py
+++ b/tensorflow/compiler/tests/tensor_array_ops_test.py
@@ -335,7 +335,7 @@ class TensorArrayTest(xla_test.XLATestCase):
       r0_bad = gen_data_flow_ops._tensor_array_read_v3(
           handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
       with self.assertRaisesOpError(
-          "TensorArray dtype is float but Op requested dtype double."):
+          "TensorArray dtype is float but op has dtype double."):
         r0_bad.eval()
 
       # Test reading from a different index than the one we wrote to
diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
index 620fc84437..6ad72c6219 100644
--- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
@@ -51,13 +51,26 @@ class ArgOp : public XlaOpKernel {
     XlaContext& xc = XlaContext::Get(ctx);
     const XlaContext::Argument& arg = xc.args()[index_];
-    if (arg.is_variable) {
+    if (arg.is_resource) {
+      XlaResource::Kind kind;
+      switch (arg.kind) {
+        case XlaCompiler::Argument::kVariable:
+          kind = XlaResource::kVariable;
+          break;
+        case XlaCompiler::Argument::kTensorArray:
+          kind = XlaResource::kTensorArray;
+          break;
+        default:
+          CHECK(false);
+      }
+
       // TODO(phawkins): this code assumes that variables do not alias.
-      XlaVariable* var;
-      OP_REQUIRES_OK(ctx, xc.CreateVariable(index_, arg.name, arg.value.type,
-                                            arg.value.handle, &var));
-      var->tensor_array_size = arg.tensor_array_size;
-      ctx->SetVariableOutput(0, var);
+      XlaResource* resource;
+      OP_REQUIRES_OK(ctx,
+                     xc.CreateResource(kind, index_, arg.name, arg.value.type,
+                                       arg.value.handle, &resource));
+      resource->tensor_array_size = arg.tensor_array_size;
+      ctx->SetResourceOutput(0, resource);
     } else if (arg.value.is_constant) {
       ctx->SetConstantOutput(0, arg.value.constant_value);
     } else {
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index c7510bf3d2..598b341002 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -41,32 +41,36 @@ namespace {
 // Since the element shape is not always provided to the TensorArrayV3 operator,
 // we must support lazily initialization of the TensorArray at the time of the
 // first write.
-// If a TensorArray `var` has not been initialized, constructs storage for the
-// TensorArray with elements of `elem_shape`. For both initialized and
+// If a TensorArray `resource` has not been initialized, constructs storage for
+// the TensorArray with elements of `elem_shape`. For both initialized and
 // uninitialized TensorArrays, checks that the tensor has a type compatible with
 // 'dtype' and shape compatible with 'elem_shape'.
 Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
-                                  XlaVariable* var, DataType dtype,
+                                  XlaResource* resource, DataType dtype,
                                   const TensorShape& elem_shape) {
-  if (var->type != dtype) {
+  if (resource->kind != XlaResource::kTensorArray) {
+    return errors::InvalidArgument("Unexpected non-TensorArray resource");
+  }
+
+  if (resource->type != dtype) {
     return errors::InvalidArgument(
-        "TensorArray dtype is ", DataTypeString(var->type),
+        "TensorArray dtype is ", DataTypeString(resource->type),
         " but op has dtype ", DataTypeString(dtype), ".");
   }
-  TF_RET_CHECK(var->tensor_array_size >= 0)
-      << var->name << " size " << var->tensor_array_size;
+  TF_RET_CHECK(resource->tensor_array_size >= 0)
+      << resource->name << " size " << resource->tensor_array_size;
   TensorShape ta_shape;
-  ta_shape.AddDim(var->tensor_array_size);
+  ta_shape.AddDim(resource->tensor_array_size);
   ta_shape.AppendShape(elem_shape);
-  if (var->value.handle() == 0) {
+  if (resource->value.handle() == 0) {
     // TensorArray has not been initialized.
-    xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type);
-    var->value = builder->Broadcast(zero, ta_shape.dim_sizes());
+    xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type);
+    resource->value = builder->Broadcast(zero, ta_shape.dim_sizes());
   } else {
     // Checks the elem_shape matches the TensorArray shape.
-    auto shape_or_status = builder->GetShape(var->value);
+    auto shape_or_status = builder->GetShape(resource->value);
     if (!shape_or_status.ok()) {
       return shape_or_status.status();
     }
@@ -80,6 +84,44 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder,
   return Status::OK();
 }
 
+// Checks that the TensorArray 'resource' has been initialized, and has type
+// 'dtype'. Sets 'shape' to the shape
+Status CheckTensorArrayIsInitialized(const string& op_name,
+                                     const XlaResource* resource,
+                                     DataType dtype) {
+  if (resource->kind != XlaResource::kTensorArray) {
+    return errors::InvalidArgument(
+        "Unexpected non-TensorArray resource passed "
+        "to ",
+        op_name);
+  }
+  if (resource->value.handle() == 0) {
+    return errors::InvalidArgument("Uninitialized TensorArray passed to ",
+                                   op_name);
+  }
+  if (resource->type != dtype) {
+    return errors::InvalidArgument(
+        "TensorArray dtype is ", DataTypeString(resource->type),
+        " but op has dtype ", DataTypeString(dtype), ".");
+  }
+
+  return Status::OK();
+}
+
+Status GetTensorArrayShape(const XlaResource* resource,
+                           xla::ComputationBuilder* builder,
+                           TensorShape* shape) {
+  auto shape_or_status = builder->GetShape(resource->value);
+  if (!shape_or_status.ok()) {
+    return shape_or_status.status();
+  }
+  *shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie());
+  if (shape->dims() < 1) {
+    return errors::InvalidArgument("TensorArray rank must be >= 1");
+  }
+  return Status::OK();
+}
+
 // Pads 'x' with 'count' zero indices. 'x' must have 1 element.
 xla::ComputationDataHandle PadIndexWithZeros(
     xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
@@ -125,7 +167,6 @@ class TensorArrayOp : public XlaOpKernel {
                 errors::InvalidArgument("TensorArray size must be >= 0"));
 
     xla::ComputationBuilder* b = ctx->builder();
-    b->set_die_immediately_on_error(true);
 
     // Initializes the TensorArray value if we know the element shape.
     // Otherwise, defer initialization to the first write.
@@ -141,12 +182,13 @@ class TensorArrayOp : public XlaOpKernel {
     }
 
     XlaContext& xc = XlaContext::Get(ctx);
-    XlaVariable* var;
+    XlaResource* var;
     string name = strings::StrCat("TensorArray: ", tensor_array_name_);
-    OP_REQUIRES_OK(ctx,
-                   xc.CreateVariable(-1, std::move(name), dtype_, value, &var));
+    OP_REQUIRES_OK(
+        ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
+                               dtype_, value, &var));
     var->tensor_array_size = size;
-    ctx->SetVariableOutput(0, var);
+    ctx->SetResourceOutput(0, var);
     ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
   }
 
@@ -173,11 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel {
 
     // Initializes the TensorArray, if the element shape was not known at
     // construction time.
-    XlaVariable* var;
-    OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
-    OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+    XlaResource* resource;
+    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+    OP_REQUIRES_OK(ctx,
+                   MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
 
-    xla::ComputationDataHandle ta = var->value;
+    xla::ComputationDataHandle ta = resource->value;
     xla::ComputationDataHandle index = ctx->Input(1);
     xla::ComputationDataHandle value = ctx->Input(2);
@@ -191,7 +234,7 @@ class TensorArrayWriteOp : public XlaOpKernel {
     xla::ComputationDataHandle written =
         DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices);
 
-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written));
+    resource->value = written;
     ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
   }
@@ -210,20 +253,17 @@ class TensorArrayReadOp : public XlaOpKernel {
   }
 
   void Compile(XlaOpKernelContext* ctx) override {
-    DataType ta_type;
-    TensorShape ta_shape;
-    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
-    OP_REQUIRES(ctx, ta_type == dtype_,
-                errors::InvalidArgument(
-                    "TensorArray dtype is ", DataTypeString(ta_type),
-                    " but Op requested dtype ", DataTypeString(dtype_), "."));
-    OP_REQUIRES(ctx, ta_shape.dims() >= 1,
-                errors::InvalidArgument("TensorArray rank must be >= 1"));
-
     xla::ComputationBuilder* b = ctx->builder();
 
-    xla::ComputationDataHandle ta;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+    XlaResource* resource;
+    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+
+    OP_REQUIRES_OK(ctx,
+                   CheckTensorArrayIsInitialized(name(), resource, dtype_));
+    TensorShape ta_shape;
+    OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
+
+    xla::ComputationDataHandle ta = resource->value;
     xla::ComputationDataHandle index = ctx->Input(1);
 
     // start_indices of the DynamicSlice are [index, 0, 0, ..., 0].
@@ -255,13 +295,15 @@ class TensorArrayGatherOp : public XlaOpKernel {
   }
 
   void Compile(XlaOpKernelContext* ctx) override {
-    DataType ta_type;
+    xla::ComputationBuilder* b = ctx->builder();
+
+    XlaResource* resource;
+    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+
+    OP_REQUIRES_OK(ctx,
+                   CheckTensorArrayIsInitialized(name(), resource, dtype_));
     TensorShape ta_shape;
-    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
-    OP_REQUIRES(ctx, ta_type == dtype_,
-                errors::InvalidArgument("TensorArray type mismatch"));
-    OP_REQUIRES(ctx, ta_shape.dims() >= 1,
-                errors::InvalidArgument("TensorArray rank must be >= 1"));
+    OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
 
     const TensorShape indices_shape = ctx->InputShape(1);
     OP_REQUIRES(ctx, indices_shape.dims() >= 1,
@@ -269,10 +311,7 @@ class TensorArrayGatherOp : public XlaOpKernel {
     const int num_indices = indices_shape.dim_size(0);
     auto indices = ctx->Input(1);
 
-    xla::ComputationBuilder* b = ctx->builder();
-
-    xla::ComputationDataHandle ta;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+    xla::ComputationDataHandle ta = resource->value;
 
     // For each index in `indices`, add the corresponding slice to `slices`.
     std::vector<xla::ComputationDataHandle> slices(num_indices);
@@ -320,11 +359,12 @@ class TensorArrayScatterOp : public XlaOpKernel {
 
     const TensorShape value_shape = ctx->InputShape(2);
 
-    XlaVariable* var;
-    OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
+    XlaResource* resource;
+    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
     TensorShape elem_shape = value_shape;
     elem_shape.RemoveDim(0);
-    OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
+    OP_REQUIRES_OK(ctx,
+                   MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
 
     const TensorShape indices_shape = ctx->InputShape(1);
     OP_REQUIRES(ctx, indices_shape.dims() >= 1,
@@ -332,7 +372,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
     const int num_indices = indices_shape.dim_size(0);
     const xla::ComputationDataHandle indices = ctx->Input(1);
 
-    xla::ComputationDataHandle ta = var->value;
+    xla::ComputationDataHandle ta = resource->value;
     const xla::ComputationDataHandle value = ctx->Input(2);
 
     auto slice_dims = value_shape.dim_sizes();
@@ -355,7 +395,7 @@ class TensorArrayScatterOp : public XlaOpKernel {
       ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices);
     }
 
-    OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta));
+    resource->value = ta;
     ctx->SetConstantOutput(0, Tensor(DT_FLOAT));
   }
@@ -374,18 +414,17 @@ class TensorArrayConcatOp : public XlaOpKernel {
   }
 
   void Compile(XlaOpKernelContext* ctx) override {
-    DataType ta_type;
-    TensorShape ta_shape;
-    OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape));
-    OP_REQUIRES(ctx, ta_type == dtype_,
-                errors::InvalidArgument("TensorArray type mismatch"));
-    OP_REQUIRES(ctx, ta_shape.dims() >= 1,
-                errors::InvalidArgument("TensorArray rank must be >= 1"));
-
     xla::ComputationBuilder* b = ctx->builder();
 
-    xla::ComputationDataHandle ta;
-    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta));
+    XlaResource* resource;
+    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+
+    OP_REQUIRES_OK(ctx,
+                   CheckTensorArrayIsInitialized(name(), resource, dtype_));
+    TensorShape ta_shape;
+    OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape));
+
+    xla::ComputationDataHandle ta = resource->value;
 
     auto ta_dims = ta_shape.dim_sizes();
     std::vector<int64> shape(ta_dims.begin() + 1, ta_dims.end());
@@ -436,19 +475,20 @@ class TensorArraySplitOp : public XlaOpKernel {
     elem_shape.set_dim(0, length);
 
     xla::ComputationBuilder* b = ctx->builder();
-    XlaVariable* var;
-    OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var));
-    OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape));
-    xla::ComputationDataHandle ta = var->value;
+    XlaResource* resource;
+    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
+    OP_REQUIRES_OK(ctx,
+                   MaybeInitializeTensorArray(b, resource, dtype_, elem_shape));
+    xla::ComputationDataHandle ta = resource->value;
 
     TensorShape ta_shape;
-    ta_shape.AddDim(var->tensor_array_size);
+    ta_shape.AddDim(resource->tensor_array_size);
     ta_shape.AppendShape(elem_shape);
 
-    OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size,
+    OP_REQUIRES(ctx, lengths.size() == resource->tensor_array_size,
                 errors::InvalidArgument(
                     "TensorArray's size is not equal to the size of lengths (",
-                    lengths.size(), " vs. ", var->tensor_array_size, ")"));
+                    lengths.size(), " vs. ", resource->tensor_array_size, ")"));
 
     const xla::ComputationDataHandle value = ctx->Input(1);
 
@@ -457,8 +497,7 @@ class TensorArraySplitOp : public XlaOpKernel {
                                         value_shape.DebugString(), " vs. ",
", ta_shape.DebugString())); - ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta)); + resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); } @@ -476,8 +515,8 @@ class TensorArraySizeOp : public XlaOpKernel { explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* var; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); Tensor size_tensor(DT_INT32, {}); size_tensor.scalar<int32>()() = static_cast<int32>(var->tensor_array_size); ctx->SetConstantOutput(0, size_tensor); @@ -498,31 +537,31 @@ class TensorArrayGradOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* b = ctx->builder(); - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); - DataType ta_type; + OP_REQUIRES_OK( + ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type)); TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); // Finds or looks up the corresponding gradient TensorArray, which stores // gradients computed during backpropagation. - XlaVariable*& gradient = var->tensor_array_gradient[source_]; + XlaResource*& gradient = resource->tensor_array_gradient[source_]; if (!gradient) { - xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type); + xla::ComputationDataHandle zero = XlaHelpers::Zero(b, resource->type); xla::ComputationDataHandle value = b->Broadcast(zero, ta_shape.dim_sizes()); XlaContext& xc = XlaContext::Get(ctx); - string name = strings::StrCat("TensorArrayGrad: ", var->name); - OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type, - value, &gradient)); - gradient->tensor_array_size = var->tensor_array_size; + string name = strings::StrCat("TensorArrayGrad: ", resource->name); + OP_REQUIRES_OK( + ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), + resource->type, value, &gradient)); + gradient->tensor_array_size = resource->tensor_array_size; } - ctx->SetVariableOutput(0, gradient); + ctx->SetResourceOutput(0, gradient); ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); } diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index ae9a358e22..0caa9c5f37 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -42,29 +42,38 @@ Status MakeXlaCompilerArgumentsFromInputs( << " shape: " << ctx->InputShape(i).DebugString(); XlaCompiler::Argument& arg = (*args)[i]; DataType type = ctx->input_type(i); - // When reading a variable input, use the type and shape of the variable's + // When reading a resource input, use the type and shape of the resource's // current value. 
     if (type == DT_RESOURCE) {
-      XlaVariable* var;
-      TF_RETURN_IF_ERROR(ctx->GetVariableInput(i, &var));
-
-      bool initialized = var->value.handle() > 0;
-      if (initialized) {
-        arg.kind = XlaCompiler::Argument::kVariable;
-        TF_RETURN_IF_ERROR(
-            ctx->GetVariableTypeAndShape(i, &arg.type, &arg.shape));
+      XlaResource* resource;
+      TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource));
+
+      arg.initialized = resource->value.handle() > 0;
+      switch (resource->kind) {
+        case XlaResource::kVariable:
+          arg.kind = XlaCompiler::Argument::kVariable;
+          break;
+        case XlaResource::kTensorArray:
+          arg.kind = XlaCompiler::Argument::kTensorArray;
+          break;
+        case XlaResource::kInvalid:
+          CHECK(false);
+      }
+      arg.type = resource->type;
+      if (arg.initialized) {
+        auto shape = ctx->builder()->GetShape(resource->value);
+        TF_RETURN_IF_ERROR(shape.status());
+        arg.shape = XLAShapeToTensorShape(*shape.ValueOrDie());
       } else {
-        arg.kind = XlaCompiler::Argument::kUninitializedVariable;
-        arg.type = var->type;
         *has_uninitialized_vars = true;
       }
-      arg.tensor_array_size = var->tensor_array_size;
-      arg.name = var->name;
+      arg.tensor_array_size = resource->tensor_array_size;
+      arg.name = resource->name;
       // TODO(phawkins): propagate TensorArray gradients into loops.
-      VLOG(2) << "  variable " << var->name
+      VLOG(2) << "  resource " << resource->name
               << " type: " << DataTypeString(arg.type)
               << " shape: " << arg.shape.DebugString()
-              << " initialized: " << initialized;
+              << " initialized: " << arg.initialized;
 
     } else {
       arg.kind = XlaCompiler::Argument::kParameter;
@@ -86,7 +95,7 @@ XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
 }
 
 void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
-  VLOG(1) << "WhileOp::Compute";
+  VLOG(1) << "WhileOp::Compile";
 
   std::vector<XlaCompiler::Argument> arguments;
   bool has_uninitialized_vars;
@@ -100,16 +109,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
 
   VLOG(1) << "Compiling body";
 
-  // All resource variables that are inputs to the loop's body must also be
+  // All resources that are inputs to the loop's body must also be
   // present as loop body outputs; the signature of the loop's input and
   // output must match. We ensure this by asking the compiler to include the
-  // current values of all variables, even if they haven't been updated by the
+  // current values of all resources, even if they haven't been updated by the
   // computation.
   // TODO(phawkins): consider adding loop-invariant inputs to XLA's While()
   // operator.
   XlaCompiler::CompileOptions body_options;
   body_options.use_tuple_arg = use_tuple_arg;
-  body_options.return_updated_values_for_all_variables = true;
+  body_options.return_updated_values_for_all_resources = true;
   XlaCompiler::CompilationResult body;
   OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
                                                 arguments, &body));
@@ -118,26 +127,28 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
   // we may not know the shape of a TensorArray if it is first written inside
   // the loop. Ideally we would require the user to provide a static shape,
   // but this is not always easy.
-  // So if uninitialized variables are used by the loop body, we compile the
+  // So if uninitialized resources are used by the loop body, we compile the
   // body function twice:
-  // 1) once with uninitialized variable inputs. We discard the computation
-  //    but we assume variable shapes reach a fixpoint after one iteration.
-  //    So we can use the output shapes of the variables as the "true" shapes.
+  // 1) once with uninitialized resource inputs. We discard the computation
+  //    but we assume resource shapes reach a fixpoint after one iteration.
+  //    So we can use the output shapes of the resources as the "true" shapes.
   // 2) again with the "correct" input shapes determined by (1).
   if (has_uninitialized_vars) {
-    // Initializes any uninitialized variables with zero values of the
+    // Initializes any uninitialized resources with zero values of the
     // shape determined by the first compilation.
-    for (int i = 0; i < body.variable_updates.size(); ++i) {
-      const XlaCompiler::VariableUpdate& update = body.variable_updates[i];
+    for (int i = 0; i < body.resource_updates.size(); ++i) {
+      const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
       XlaCompiler::Argument& arg = arguments[update.input_index];
-      if (arg.kind == XlaCompiler::Argument::kUninitializedVariable) {
-        arg.kind = XlaCompiler::Argument::kVariable;
+      if (!arg.initialized) {
+        arg.initialized = true;
         arg.shape = update.shape;
+        XlaResource* resource;
+        OP_REQUIRES_OK(ctx,
+                       ctx->GetResourceInput(update.input_index, &resource));
+
         xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, arg.type);
-        auto value = builder->Broadcast(zero, update.shape.dim_sizes());
-        OP_REQUIRES_OK(
-            ctx, ctx->AssignVariable(update.input_index, arg.type, value));
+        resource->value = builder->Broadcast(zero, update.shape.dim_sizes());
       }
     }
     // Recompile the body with the "correct" shapes.
@@ -191,7 +202,9 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
   for (int i = 0; i < num_inputs; ++i) {
     int input_num = body.input_mapping[i];
     if (ctx->input_type(input_num) == DT_RESOURCE) {
-      OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(input_num, &inputs[i]));
+      XlaResource* resource;
+      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
+      inputs[i] = resource->value;
     } else {
       inputs[i] = ctx->Input(i);
     }
@@ -225,16 +238,16 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
   }
 
   // Updates the values of any resource variables modified by the loop.
-  for (int i = 0; i < body.variable_updates.size(); ++i) {
-    const XlaCompiler::VariableUpdate& update = body.variable_updates[i];
+  for (int i = 0; i < body.resource_updates.size(); ++i) {
+    const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
+    XlaResource* resource;
+    OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
    if (update.modified) {
       int pos = body.outputs.size() + i;
-      OP_REQUIRES_OK(ctx, ctx->AssignVariable(update.input_index, update.type,
-                                              get_loop_output(pos)));
+      resource->value = get_loop_output(pos);
     }
     VLOG(2) << "Loop-carried variable: pos: " << update.input_index
-            << " name: " << ctx->VariableDebugString(update.input_index)
-            << " modified: " << update.modified
+            << " name: " << resource->name << " modified: " << update.modified
             << " type: " << DataTypeString(update.type)
             << " shape: " << update.shape.DebugString();
     // Copies the identity of the resource variable from input to output
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h
index 75630bee39..e4f43f1950 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.h
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h
@@ -64,26 +64,35 @@ class XlaCompilationDevice : public LocalDevice {
   std::unique_ptr<XlaCompilationAllocator> allocator_;
 };
 
-struct XlaVariable {
-  // If this variable is visible externally, what was its argument number?
+// Represents a resource, such as a Variable or TensorArray.
+struct XlaResource {
+  enum Kind {
+    kInvalid,
+    kVariable,
+    kTensorArray,
+  };
+
+  Kind kind = kInvalid;
+
+  // If this resource is visible externally, what was its argument number?
   int arg_num = -1;
 
-  // A descriptive name for the variable, used in error messages.
+  // A descriptive name for the resource, used in error messages.
   string name;
 
-  // Current type and value of the variable. Uninitialized variables are
+  // Current type and value of the resource. Uninitialized resources are
   // represented by a default (zero) handle and type DT_INVALID.
-  // While the type of a variable is notionally fixed during execution, when
-  // a variable is first initialized we do not yet know its type, so we keep
+  // While the type of a resource is notionally fixed during execution, when
+  // a resource is first initialized we do not yet know its type, so we keep
   // track of its type dynamically.
   DataType type = DT_INVALID;
   xla::ComputationDataHandle value;
 
-  // Value of the variable at computation entry. Used to detect which
+  // Value of the resource at computation entry. Used to detect which
   // variables have new values that need to be written back.
   xla::ComputationDataHandle initial_value;
 
-  // We treat TensorArrays as a Variable with some extra metadata.
+  // TensorArray-specific fields
 
   // 'tensor_array_size' stores the expected size of the TensorArray. We need
   // to store this since sometimes TensorArrays must be initialized lazily since
@@ -91,10 +100,10 @@ struct XlaVariable {
   int64 tensor_array_size = -1;
 
   // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes
-  // to an XlaVariable containing the gradient TensorArrays. We store a pointer
+  // to an XlaResource containing the gradient TensorArrays. We store a pointer
   // here since there should only be one gradient TensorArray per 'source'
   // string, irrespective of the number of calls to TensorArrayGrad.
-  std::unordered_map<string, XlaVariable*> tensor_array_gradient;
+  std::unordered_map<string, XlaResource*> tensor_array_gradient;
 };
 
 // A XlaExpression wraps an XLA computation. Each Tensor on an
@@ -115,8 +124,8 @@ class XlaExpression {
   bool has_constant_value() const { return has_constant_value_; }
   const Tensor& constant_value() const { return constant_value_; }
 
-  void set_variable(XlaVariable* variable) { variable_ = variable; }
-  XlaVariable* variable() const { return variable_; }
+  void set_resource(XlaResource* resource) { resource_ = resource; }
+  XlaResource* resource() const { return resource_; }
 
  private:
   // The XLA handle of the expression's computation.
@@ -128,7 +137,7 @@ class XlaExpression {
   bool has_constant_value_ = false;
   Tensor constant_value_;
 
-  XlaVariable* variable_ = nullptr;  // Not owned.
+  XlaResource* resource_ = nullptr;  // Not owned.
 
   TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression);
 };
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index d7c92a710c..50b384997a 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -251,35 +251,36 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
                       std::vector<xla::Shape>* input_shapes) {
   context_args->resize(args.size());
 
-  // Argument numbers of arguments and variables that are to be passed to the
+  // Argument numbers of arguments and resources that are to be passed to the
   // XLA computation as runtime parameters.
-  std::vector<int> parameters, variables;
+  std::vector<int> parameters, resources;
   parameters.reserve(args.size());
-  variables.reserve(args.size());
+  resources.reserve(args.size());
 
   for (std::vector<XlaCompiler::Argument>::size_type i = 0; i < args.size();
        ++i) {
     XlaContext::Argument& context_arg = (*context_args)[i];
+    context_arg.kind = args[i].kind;
     context_arg.name = args[i].name;
     context_arg.value.constant_value = args[i].constant_value;
     context_arg.value.type = args[i].type;
 
     switch (args[i].kind) {
       case XlaCompiler::Argument::kVariable:
-        variables.push_back(i);
-        context_arg.is_variable = true;
-        context_arg.value.is_constant = false;
+      case XlaCompiler::Argument::kTensorArray:
+        context_arg.is_resource = true;
+        if (args[i].initialized) {
+          resources.push_back(i);
+          context_arg.value.is_constant = false;
+        } else {
+          context_arg.value.is_constant = true;
+        }
         context_arg.tensor_array_size = args[i].tensor_array_size;
         break;
       case XlaCompiler::Argument::kParameter:
         parameters.push_back(i);
         context_arg.value.is_constant = false;
         break;
-      case XlaCompiler::Argument::kUninitializedVariable:
-        context_arg.is_variable = true;
-        context_arg.value.is_constant = true;
-        context_arg.tensor_array_size = args[i].tensor_array_size;
-        break;
       case XlaCompiler::Argument::kConstant:
         context_arg.value.is_constant = true;
        break;
@@ -290,7 +291,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
 
   // Append parameters containing variable values after the other runtime
   // parameters.
-  parameters.insert(parameters.end(), variables.begin(), variables.end());
+  parameters.insert(parameters.end(), resources.begin(), resources.end());
   if (parameters.empty()) {
     return Status::OK();
   }
@@ -331,22 +332,22 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
 // variable states, generated by the symbolic evaluation.
 // If `has_side_effects` is true, the computation has side effects and should be
 // built even if it has no outputs.
-// If `return_updated_values_for_all_variables` is true, all variables will be
-// included in `variable_updates`, regardless of whether their value changed.
+// If `return_updated_values_for_all_resources` is true, all resources will be
+// included in `resource_updates`, regardless of whether their value changed.
 // Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
-// Sets `*variable_updates` to a description of variables whose values are
+// Sets `*resource_updates` to a description of resources whose values are
 // written by the computation; the variable writes are the last
-// `variable_updates.size()` return values from the computation. Each entry in
-// `variable_updates` is a (input_index, type) pair, where `input_index` is the
+// `resource_updates.size()` return values from the computation. Each entry in
+// `resource_updates` is a (input_index, type) pair, where `input_index` is the
 // index of a resource variable argument to the computation, and `type` is the
 // type of the final output.
 Status BuildComputation(
     const std::vector<XlaContext::HandleOrConstant>& retvals,
-    const std::vector<std::unique_ptr<XlaVariable>>& variables,
-    bool has_side_effects, bool return_updated_values_for_all_variables,
+    const std::vector<std::unique_ptr<XlaResource>>& resources,
+    bool has_side_effects, bool return_updated_values_for_all_resources,
     xla::ComputationBuilder* builder, xla::Computation* computation,
     int* num_nonconst_outputs,
-    std::vector<XlaCompiler::VariableUpdate>* variable_updates) {
+    std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
   std::vector<xla::ComputationDataHandle> elems;
   elems.reserve(retvals.size());
   for (const XlaContext::HandleOrConstant& retval : retvals) {
@@ -356,24 +357,24 @@ Status BuildComputation(
   }
   *num_nonconst_outputs = elems.size();
 
-  // Add return values for variables whose values have changed.
-  std::vector<const XlaVariable*> arg_vars;
-  arg_vars.reserve(variables.size());
-  for (const auto& var : variables) {
+  // Add return values for resources whose values have changed.
+  std::vector<const XlaResource*> arg_vars;
+  arg_vars.reserve(resources.size());
+  for (const auto& var : resources) {
     if (var->arg_num >= 0) {
       arg_vars.push_back(var.get());
     }
   }
   std::sort(arg_vars.begin(), arg_vars.end(),
-            [](const XlaVariable* a, const XlaVariable* b) {
+            [](const XlaResource* a, const XlaResource* b) {
               return a->arg_num < b->arg_num;
             });
 
-  for (const XlaVariable* var : arg_vars) {
+  for (const XlaResource* var : arg_vars) {
     bool modified = var->value.handle() != var->initial_value.handle();
-    if (return_updated_values_for_all_variables || modified) {
-      variable_updates->emplace_back();
-      XlaCompiler::VariableUpdate& update = variable_updates->back();
+    if (return_updated_values_for_all_resources || modified) {
+      resource_updates->emplace_back();
+      XlaCompiler::ResourceUpdate& update = resource_updates->back();
       update.input_index = var->arg_num;
       update.type = var->type;
       update.modified = modified;
@@ -439,10 +440,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   int num_nonconst_outputs;
   result->computation = std::make_shared<xla::Computation>();
   TF_RETURN_IF_ERROR(BuildComputation(
-      context->retvals(), context->variables(), context->has_side_effects(),
-      options.return_updated_values_for_all_variables, &builder,
+      context->retvals(), context->resources(), context->has_side_effects(),
+      options.return_updated_values_for_all_resources, &builder,
       result->computation.get(), &num_nonconst_outputs,
-      &result->variable_updates));
+      &result->resource_updates));
 
   result->requires_runtime_context = context->has_context_parameter();
 
@@ -517,15 +518,15 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
     }
   }
 
-  for (std::vector<VariableUpdate>::size_type i = 0;
-       i < result->variable_updates.size(); ++i) {
+  for (std::vector<ResourceUpdate>::size_type i = 0;
+       i < result->resource_updates.size(); ++i) {
     if (num_computation_outputs > 1) {
-      result->variable_updates[i].shape =
+      result->resource_updates[i].shape =
           XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape(
              result->xla_output_shape, computation_output));
     } else {
       CHECK_EQ(0, computation_output);
-      result->variable_updates[i].shape =
+      result->resource_updates[i].shape =
           XLAShapeToTensorShape(result->xla_output_shape);
     }
     ++computation_output;
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 6b9bf4159d..58e42c3474 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -85,14 +85,14 @@ class XlaCompiler {
     // Argument is a compile-time constant. No associated runtime parameter.
     kConstant,
 
-    // Argument is a variable that has not been initialized yet. No associated
-    // runtime parameter.
-    kUninitializedVariable,
-
-    // Argument is a variable that already has a value set. Expects a runtime
-    // parameter containing the current value.
+    // Argument is a variable resource. Has an associated runtime parameter
+    // iff `initialized` is true.
     kVariable,
 
+    // Argument is a TensorArray resource. Has an associated runtime parameter
+    // iff `initialized` is true.
+    kTensorArray,
+
     // Argument is a run-time parameter.
     kParameter,
   };
@@ -114,8 +114,11 @@ class XlaCompiler {
     // The name of this argument, used for debugging.
     string name;
 
-    // For a kVariable or kUninitializedVariable corresponding to a TensorArray,
-    // what is the tensor array's declared size?
+    // For a kVariable or kTensorArray, has this resource been initialized?
+    bool initialized = false;
+
+    // For a kTensorArray, what is the array's declared size? (Used for lazy
+    // initialization.)
     int64 tensor_array_size = -1;
 
     bool operator==(const Argument& other) const;
@@ -133,7 +136,7 @@ class XlaCompiler {
   };
 
   // Describes a variable write side effect of the computation.
-  struct VariableUpdate {
+  struct ResourceUpdate {
     // Index of the input that contains the variable resource to write to.
     int input_index;
 
@@ -142,14 +145,14 @@ class XlaCompiler {
     TensorShape shape;
 
     // Was the value of the variable modified by the computation?
-    // (Always true, unless `return_updated_values_for_all_variables` is true.)
+    // (Always true, unless `return_updated_values_for_all_resources` is true.)
     bool modified;
   };
 
   struct CompilationResult {
     // Vector that maps from the parameters of the XLA computation to their
     // original argument positions. To handle compile-time constant inputs and
-    // variables, the parameters to the XLA computation may be a subset of the
+    // resources, the parameters to the XLA computation may be a subset of the
     // original arguments, and are not necessarily in the same order.)
     std::vector<int> input_mapping;
 
@@ -172,10 +175,10 @@ class XlaCompiler {
     // containing both constant and non-constant results.
     std::vector<OutputDescription> outputs;
 
-    // Variables whose values were updated by the computation, ordered
-    // by return value position. Variable updates follow the non-constant
+    // Resources whose values were updated by the computation, ordered
+    // by return value position. Resource updates follow the non-constant
     // results in the outputs of XLA computation.
-    std::vector<VariableUpdate> variable_updates;
+    std::vector<ResourceUpdate> resource_updates;
 
     // The XLA computation built from the tensorflow subgraph. May be null
     // if the output consists solely of compile-time constants.
@@ -229,12 +232,12 @@ class XlaCompiler {
     // arguments; if false, each argument gets its own parameter.
     bool use_tuple_arg = false;
 
-    // If 'return_updated_values_for_all_variables' is true, then updated
-    // values of all resource variables arguments will be included in the
-    // 'variable_updates' of the computation, even if the variable was not
+    // If 'return_updated_values_for_all_resources' is true, then updated
+    // values of all resource arguments will be included in the
+    // 'resource_updates' of the computation, even if the resource was not
     // modified by the computation. Used when compiling loop bodies to ensure
    // the input and output signatures match.
-    bool return_updated_values_for_all_variables = false;
+    bool return_updated_values_for_all_resources = false;
   };
 
   // Compiles a Tensorflow function `fn_name_attrs` into an XLA computation.
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 4440b53069..1a37d61944 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -129,16 +129,18 @@ void XlaContext::AddSideEffects() {
 
 xla::ComputationBuilder* XlaContext::builder() { return builder_; }
 
-Status XlaContext::CreateVariable(int arg_num, string name, DataType type,
+Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num,
+                                  string name, DataType type,
                                   const xla::ComputationDataHandle& handle,
-                                  XlaVariable** variable) {
-  variables_.emplace_back(new XlaVariable);
-  *variable = variables_.back().get();
-  XlaVariable& var = **variable;
-  var.arg_num = arg_num;
-  var.name = std::move(name);
-  var.type = type;
-  var.initial_value = var.value = handle;
+                                  XlaResource** resource) {
+  resources_.emplace_back(new XlaResource);
+  *resource = resources_.back().get();
+  XlaResource& r = **resource;
+  r.kind = kind;
+  r.arg_num = arg_num;
+  r.name = std::move(name);
+  r.type = type;
+  r.initial_value = r.value = handle;
   return Status::OK();
 }
 
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 3978baaf63..dbede52b5d 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -52,11 +52,13 @@ class XlaContext : public ResourceBase {
   };
 
   struct Argument {
-    // Descriptive name for the variable, for use in error messages.
+    XlaCompiler::Argument::Kind kind;
+
+    // Descriptive name for the resource, for use in error messages.
     string name;
 
-    // Is this a variable?
-    bool is_variable = false;
+    // Is this a resource?
+    bool is_resource = false;
 
     HandleOrConstant value;
 
@@ -106,15 +108,15 @@ class XlaContext : public ResourceBase {
 
   bool has_side_effects() const { return has_side_effects_; }
 
-  // Creates a variable with variable `variable_id` and initial type `type` and
+  // Creates a resource of kind `kind` and initial type `type` and
   // value `handle`. `name` is a descriptive name for use in error messages.
-  // Fails if the variable already exists.
-  Status CreateVariable(int arg_num, string name, DataType type,
-                        const xla::ComputationDataHandle& handle,
-                        XlaVariable** variable);
+  // Fails if the resource already exists.
+  Status CreateResource(XlaResource::Kind kind, int arg_num, string name,
+                        DataType type, const xla::ComputationDataHandle& handle,
+                        XlaResource** resource);
 
-  const std::vector<std::unique_ptr<XlaVariable>>& variables() {
-    return variables_;
+  const std::vector<std::unique_ptr<XlaResource>>& resources() {
+    return resources_;
   }
 
   // Get an XLA lambda to compute Max. This is cached in the
@@ -166,8 +168,8 @@ class XlaContext : public ResourceBase {
   // Does the computation have side effects, i.e., Send() calls?
   bool has_side_effects_ = false;
 
-  // Holds ownership of variables. The variables are not ordered.
-  std::vector<std::unique_ptr<XlaVariable>> variables_;
+  // Holds ownership of resources. The resources are not ordered.
+  std::vector<std::unique_ptr<XlaResource>> resources_;
 
   // Cache of prebuilt computations indexed by their type.
   using ComputationMap = std::map<DataType, xla::Computation>;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 3272b1efa1..edb7e2a563 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -39,7 +39,7 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
   const XlaExpression* expression =
       reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
   CHECK(expression->handle().handle() != 0 ||
-        expression->variable() != nullptr);
+        expression->resource() != nullptr);
   VLOG(1) << "Fetched T" << expression->handle().handle();
   return expression;
 }
@@ -252,8 +252,9 @@ Status XlaOpKernelContext::ReadVariableInput(
     int index, xla::ComputationDataHandle* value) {
   const Tensor& tensor = context_->input(index);
   const XlaExpression* expression = CastExpressionFromTensor(tensor);
-  XlaVariable* variable = expression->variable();
+  XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
+  TF_RET_CHECK(variable->kind == XlaResource::kVariable);
   if (variable->value.handle() == 0) {
     return errors::InvalidArgument("Read of uninitialized variable ",
                                    variable->name);
@@ -262,22 +263,13 @@ Status XlaOpKernelContext::ReadVariableInput(
   return Status::OK();
 }
 
-string XlaOpKernelContext::VariableDebugString(int index) {
-  const Tensor& tensor = context_->input(index);
-  const XlaExpression* expression = CastExpressionFromTensor(tensor);
-  XlaVariable* variable = expression->variable();
-  if (!variable) {
-    return "<invalid variable ID>";
-  }
-  return variable->name;
-}
-
 Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
                                                    TensorShape* shape) const {
   const Tensor& tensor = context_->input(index);
   const XlaExpression* expression = CastExpressionFromTensor(tensor);
-  XlaVariable* variable = expression->variable();
+  XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
+  TF_RET_CHECK(variable->kind == XlaResource::kVariable);
   if (variable->value.handle() == 0) {
     return errors::InvalidArgument("Read of uninitialized variable ",
                                    variable->name);
@@ -337,33 +329,34 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
   expression->set_constant_value(constant);
 }
 
-void XlaOpKernelContext::SetVariableOutput(int index, XlaVariable* variable) {
+void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
   Tensor* output = nullptr;
-  // The shape of the output tensor is the shape of the variable resource
-  // (i.e., a scalar), not the shape of the variable's value.
+  // The shape of the output tensor is the shape of the resource itself
+  // (i.e., a scalar), not the shape of the resource's value.
   OP_REQUIRES_OK(context_,
                  context_->allocate_output(index, TensorShape(), &output));
   XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
-  expression->set_variable(variable);
+  expression->set_resource(resource);
 }
 
-Status XlaOpKernelContext::GetVariableInput(int index, XlaVariable** variable) {
+Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
   const XlaExpression* expression =
       CastExpressionFromTensor(context_->input(index));
-  TF_RET_CHECK(expression->variable() != nullptr);
-  *variable = expression->variable();
+  TF_RET_CHECK(expression->resource() != nullptr);
+  *resource = expression->resource();
   return Status::OK();
 }
 
 Status XlaOpKernelContext::AssignVariable(
-    int index, DataType type, const xla::ComputationDataHandle& handle) {
+    int input_index, DataType type, const xla::ComputationDataHandle& handle) {
   TF_RET_CHECK(handle.handle() != 0);
   SetOpHasSideEffects();
 
   const XlaExpression* expression =
-      CastExpressionFromTensor(context_->input(index));
-  XlaVariable* variable = expression->variable();
+      CastExpressionFromTensor(context_->input(input_index));
+  XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
+  TF_RET_CHECK(variable->kind == XlaResource::kVariable);
   if (!((variable->type == DT_INVALID && type != DT_INVALID) ||
         (variable->type == type))) {
     return errors::InvalidArgument(
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index a25774c3a6..b151286217 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -148,6 +148,12 @@ class XlaOpKernelContext {
 
   // Variables
 
+  // Sets '*resource' to the resource associated with input `index`.
+  Status GetResourceInput(int index, XlaResource** resource);
+
+  // Sets output 'index' to be a reference to resource 'resource'.
+  void SetResourceOutput(int index, XlaResource* resource);
+
   // Sets `*type` and `*shape` to the current type and shape of a variable's
   // value.
   Status GetVariableTypeAndShape(int index, DataType* type,
@@ -158,20 +164,10 @@ class XlaOpKernelContext {
   Status ReadVariableInput(int index, xla::ComputationDataHandle* value);
 
   // Assigns the value `handle` to the variable referenced by input
-  // `variable_index`. Marks the operator as having side effects.
-  Status AssignVariable(int variable_index, DataType type,
+  // `input_index`. Marks the operator as having side effects.
+  Status AssignVariable(int input_index, DataType type,
                         const xla::ComputationDataHandle& handle);
 
-  // Sets '*variable' to the variable associated with input `index`.
-  Status GetVariableInput(int index, XlaVariable** variable);
-
-  // Sets output 'index' to be a reference to variable 'variable'. Used
-  // to propagate resource variables through the compilation.
-  void SetVariableOutput(int index, XlaVariable* variable);
-
-  // Returns a human-readable debug string describing 'variable_index'.
-  string VariableDebugString(int variable_index);
-
   // Helper routines for the OP_REQUIRES macros
   void CtxFailure(Status s);
   void CtxFailureWithWarning(Status s);