diff options
author | 2017-09-19 19:08:19 -0700 | |
---|---|---|
committer | 2017-09-19 19:12:43 -0700 | |
commit | 1f20a786d69c4b91a4015fe3f4df8c23bd345f40 (patch) | |
tree | 9175a24a490a21587cb899b4dda98f11f83f948c /tensorflow/compiler/tf2xla | |
parent | 5ce3523bcc844217b47e7f862c1bed894cbaa34e (diff) |
[TF:XLA] Add support for reading and writing TensorArray gradients in a while loop.
Previously, there was no code to handle propagating the values of a TensorArray's gradients into and out of loops. This change passes TensorArray gradients into and out of loops by packing them up as a (base array, gradient values...) tuple.
PiperOrigin-RevId: 169338418
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc | 24 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/while_op.cc | 97 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compilation_device.cc | 100 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compilation_device.h | 43 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.cc | 65 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.h | 38 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler_test.cc | 141 |
7 files changed, 428 insertions, 80 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 7f1597e9ad..c42d8b97ea 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -114,12 +115,7 @@ Status CheckTensorArrayIsInitialized(const string& op_name, Status GetTensorArrayShape(const XlaResource* resource, xla::ComputationBuilder* builder, TensorShape* shape) { - auto shape_or_status = builder->GetShape(resource->value); - if (!shape_or_status.ok()) { - return shape_or_status.status(); - } - TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape)); + TF_RETURN_IF_ERROR(resource->GetShape(builder, shape)); if (shape->dims() < 1) { return errors::InvalidArgument("TensorArray rank must be >= 1"); } @@ -532,19 +528,9 @@ class TensorArrayGradOp : public XlaOpKernel { // Finds or looks up the corresponding gradient TensorArray, which stores // gradients computed during backpropagation. - XlaResource*& gradient = resource->tensor_array_gradient[source_]; - if (!gradient) { - xla::ComputationDataHandle zero = XlaHelpers::Zero(b, resource->type); - xla::ComputationDataHandle value = - b->Broadcast(zero, ta_shape.dim_sizes()); - - XlaContext& xc = XlaContext::Get(ctx); - string name = strings::StrCat("TensorArrayGrad: ", resource->name); - OP_REQUIRES_OK( - ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), - resource->type, value, &gradient)); - gradient->tensor_array_size = resource->tensor_array_size; - } + XlaResource* gradient; + OP_REQUIRES_OK( + ctx, resource->GetOrCreateTensorArrayGradient(source_, b, &gradient)); ctx->SetResourceOutput(0, gradient); ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 55995aa421..ead26478ff 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -33,10 +33,11 @@ namespace { // Builds XlaCompiler argument descriptions `args` from `ctx`. Status MakeXlaCompilerArgumentsFromInputs( XlaOpKernelContext* ctx, std::vector<XlaCompiler::Argument>* args, - bool* has_uninitialized_vars) { + bool* has_uninitialized_vars, bool* has_tensor_arrays) { VLOG(2) << "Num inputs " << ctx->num_inputs(); args->resize(ctx->num_inputs()); *has_uninitialized_vars = false; + *has_tensor_arrays = false; for (int i = 0; i < ctx->num_inputs(); ++i) { VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i)) @@ -52,20 +53,24 @@ Status MakeXlaCompilerArgumentsFromInputs( arg.initialized = resource->value.handle() > 0; arg.kind = XlaCompiler::Argument::kResource; arg.resource_kind = resource->kind; + if (arg.resource_kind == XlaResource::kTensorArray) { + *has_tensor_arrays = true; + } + arg.type = resource->type; if (arg.initialized) { - auto shape = ctx->builder()->GetShape(resource->value); - TF_RETURN_IF_ERROR(shape.status()); - arg.shape = *shape.ValueOrDie(); + TF_RETURN_IF_ERROR(resource->PackedShape(ctx->builder(), &arg.shape)); } else { *has_uninitialized_vars = true; } arg.tensor_array_size = resource->tensor_array_size; + for (const auto& gradient : resource->tensor_array_gradients) { + arg.tensor_array_gradients.insert(gradient.first); + } arg.name = resource->name; - // TODO(phawkins): propagate TensorArray gradients into loops. VLOG(2) << " resource " << resource->name << " type: " << DataTypeString(arg.type) - << " shape: " << arg.shape.DebugString() + << " shape: " << xla::ShapeUtil::HumanString(arg.shape) << " initialized: " << arg.initialized; } else { @@ -93,8 +98,10 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { std::vector<XlaCompiler::Argument> arguments; bool has_uninitialized_vars; - OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs( - ctx, &arguments, &has_uninitialized_vars)); + bool has_tensor_arrays; + OP_REQUIRES_OK( + ctx, MakeXlaCompilerArgumentsFromInputs( + ctx, &arguments, &has_uninitialized_vars, &has_tensor_arrays)); xla::ComputationBuilder* builder = ctx->builder(); XlaCompiler* compiler = ctx->compiler(); @@ -118,38 +125,67 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { arguments, &body)); // We must use a static shape for parameters to an XLA compilation. However, - // we may not know the shape of a TensorArray if it is first written inside - // the loop. Ideally we would require the user to provide a static shape, - // but this is not always easy. - // So if uninitialized resource are used by the loop body, we compile the - // body function twice: - // 1) once with uninitialized resource inputs. We discard the computation - // but we assume resource shapes reach a fixpoint after one iteration. - // So we can use the output shapes of the resource as the "true" shapes. - // 2) again with the "correct" input shapes determined by (1). - if (has_uninitialized_vars) { + // we may not know the shape of a resource if it is first + // written inside the loop. Furthermore, we do not know ahead of time which + // gradient TensorArrays will be created by the TensorArrayGradV3 operator. + // + // Ideally we would change TensorFlow to provide static shape always, but + // but this is not easy to do. So if uninitialized resources or TensorArrays + // are used by the loop body, we compile the body function twice: + // 1) once with uninitialized resource inputs and no TensorArray gradient + // inputs. We then discard the computation but we assume resource shapes + // and the set of gradients read or written will reach a fixpoint after one + // iteration. + // Hence we can use the output shapes and TensorArray gradients of each + // resource as the "true" shapes. + // 2) again with the "correct" resource information determined by (1). + if (has_uninitialized_vars || has_tensor_arrays) { + VLOG(2) << "Recompiling loop body: has_uninitialized_vars: " + << has_uninitialized_vars + << " has_tensor_arrays: " << has_tensor_arrays; // Initializes any uninitialized resource with zero values of the // shape determined by the first compilation. for (int i = 0; i < body.resource_updates.size(); ++i) { const XlaCompiler::ResourceUpdate& update = body.resource_updates[i]; + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource)); + XlaCompiler::Argument& arg = arguments[update.input_index]; if (!arg.initialized) { VLOG(2) << "Update shape for argument " << update.input_index << " " << xla::ShapeUtil::HumanString(update.shape); arg.initialized = true; - arg.shape = update.shape; - - XlaResource* resource; - OP_REQUIRES_OK(ctx, - ctx->GetResourceInput(update.input_index, &resource)); + xla::Shape shape = update.shape; + if (!update.tensor_array_gradients_accessed.empty()) { + shape = xla::ShapeUtil::GetTupleElementShape(shape, 0); + } std::unique_ptr<xla::Literal> zero = - xla::Literal::CreateFromShape(update.shape); + xla::Literal::CreateFromShape(shape); resource->value = builder->ConstantLiteral(*zero); } + + // Add any TensorArray gradients touched by the body to the enclosing + // graph. + for (const string& grad_source : update.tensor_array_gradients_accessed) { + VLOG(4) << "TensorArray " << resource->name << " accessed gradient " + << grad_source; + XlaResource* gradient; + OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient( + grad_source, builder, &gradient)); + } + + // Add all of the TensorArray gradients to the argument. For simplicity, + // we always pass all known gradients. + for (const auto& gradient : resource->tensor_array_gradients) { + arg.tensor_array_gradients.insert(gradient.first); + } + + // Recompute the argument shape. + OP_REQUIRES_OK(ctx, resource->PackedShape(ctx->builder(), &arg.shape)); } - // Recompile the body with the "correct" shapes. - VLOG(1) << "Recompiling body with non-placeholder shapes"; + // Recompile the body with the "correct" resource shapes. + VLOG(1) << "Recompiling body with corrected resource shapes"; body = {}; OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, arguments, &body)); @@ -203,7 +239,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { if (ctx->input_type(input_num) == DT_RESOURCE) { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); - inputs[i] = resource->value; + OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder)); } else { inputs[i] = ctx->Input(i); } @@ -244,12 +280,15 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource)); if (update.modified) { int pos = body.outputs.size() + i; - resource->value = builder->GetTupleElement(while_result, pos); + OP_REQUIRES_OK(ctx, + resource->SetFromPack( + arguments[update.input_index].tensor_array_gradients, + builder->GetTupleElement(while_result, pos), builder)); } VLOG(2) << "Loop-carried variable: pos: " << update.input_index << " name: " << resource->name << " modified: " << update.modified << " type: " << DataTypeString(update.type) - << " shape: " << update.shape.DebugString(); + << " shape: " << xla::ShapeUtil::HumanString(update.shape); // Copies the identity of the resource variable from input to output // unchanged, even if the variable was not modified. ctx->op_kernel_context()->set_output( diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 1d0098591e..4e6ef489f6 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -18,7 +18,9 @@ limitations under the License. #include <functional> #include <memory> +#include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/platform/mem.h" @@ -87,7 +89,7 @@ Allocator* XlaCompilationDevice::GetAllocator(AllocatorAttributes attr) { void XlaCompilationDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { - VLOG(1) << "XlaCompilationDevice::Compute " + VLOG(4) << "XlaCompilationDevice::Compute " << SummarizeNodeDef(op_kernel->def()); auto* b = XlaContext::Get(context).builder(); xla::OpMetadata metadata; @@ -96,7 +98,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, b->SetOpMetadata(metadata); op_kernel->Compute(context); b->ClearOpMetadata(); - VLOG(2) << "Done"; + VLOG(4) << "Done"; } Status XlaCompilationDevice::Sync() { return Status::OK(); } @@ -119,4 +121,98 @@ void XlaExpression::set_constant_value(Tensor value) { constant_value_ = std::move(value); } +Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder, + xla::Shape* shape) const { + auto shape_or_status = builder->GetShape(value); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + *shape = *shape_or_status.ValueOrDie(); + return Status::OK(); +} + +Status XlaResource::GetShape(xla::ComputationBuilder* builder, + TensorShape* shape) const { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape)); + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape)); + return Status::OK(); +} + +Status XlaResource::GetOrCreateTensorArrayGradient( + const string& source, xla::ComputationBuilder* builder, + XlaResource** gradient_out) { + VLOG(2) << "Gradient lookup for resource: " << name + << " gradient: " << source; + TF_RET_CHECK(kind == kTensorArray); + std::unique_ptr<XlaResource>& gradient = tensor_array_gradients[source]; + if (!gradient) { + gradient.reset(new XlaResource); + gradient->kind = XlaResource::kTensorArray; + gradient->name = strings::StrCat("TensorArrayGrad: ", name); + gradient->type = type; + gradient->tensor_array_size = tensor_array_size; + + TensorShape ta_shape; + TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape)); + gradient->value = builder->Broadcast(XlaHelpers::Zero(builder, type), + ta_shape.dim_sizes()); + gradient->initial_value = gradient->value; + } + *gradient_out = gradient.get(); + return Status::OK(); +} + +Status XlaResource::PackedShape(xla::ComputationBuilder* builder, + xla::Shape* packed_shape) const { + if (tensor_array_gradients.empty()) { + return GetXlaShape(builder, packed_shape); + } + TF_RET_CHECK(kind == kTensorArray); + std::vector<xla::Shape> elem_shapes(1 + tensor_array_gradients.size()); + int pos = 0; + TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++])); + for (const auto& gradient : tensor_array_gradients) { + TF_RETURN_IF_ERROR( + gradient.second->GetXlaShape(builder, &elem_shapes[pos++])); + } + *packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes); + return Status::OK(); +} + +Status XlaResource::Pack(xla::ComputationDataHandle* pack, + xla::ComputationBuilder* builder) const { + if (tensor_array_gradients.empty()) { + *pack = value; + } else { + TF_RET_CHECK(kind == kTensorArray); + std::vector<xla::ComputationDataHandle> elems; + elems.push_back(value); + for (const auto& gradient : tensor_array_gradients) { + elems.push_back(gradient.second->value); + } + *pack = builder->Tuple(elems); + } + return Status::OK(); +} + +Status XlaResource::SetFromPack(const std::set<string>& gradient_sources, + const xla::ComputationDataHandle& pack, + xla::ComputationBuilder* builder) { + if (gradient_sources.empty()) { + value = pack; + } else { + TF_RET_CHECK(kind == kTensorArray); + int pos = 0; + value = builder->GetTupleElement(pack, pos++); + for (const auto& source : gradient_sources) { + XlaResource* gradient; + TF_RETURN_IF_ERROR( + GetOrCreateTensorArrayGradient(source, builder, &gradient)); + gradient->value = builder->GetTupleElement(pack, pos++); + } + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index 22c24f4963..765683cf1d 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -18,6 +18,7 @@ limitations under the License. #include <memory> +#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" @@ -65,6 +66,7 @@ class XlaCompilationDevice : public LocalDevice { }; // Represents a resource, such as a Variable or TensorArray. +// TODO(phawkins): make this into a properly abstracted class. struct XlaResource { enum Kind { kInvalid, @@ -103,8 +105,45 @@ struct XlaResource { // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes // to an XlaResource containing the gradient TensorArrays. We store a pointer // here since there should only be one gradient TensorArray per 'source' - // string, irrespective of the number of calls to TensorArrayGrad. - std::unordered_map<string, XlaResource*> tensor_array_gradient; + // string, irrespective of the number of calls to TensorArrayGrad. The map + // is ordered since values are packed into tuples by Pack() sorted by name + // order. + std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients; + + // Returns the shape of the resource as an xla::Shape. + Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const; + + // Returns the shape of the resource as an TensorShape. Fails if the shape is + // not representable as a TensorShape. + Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const; + + // Looks up the gradient for `source`, or creates it if it does not already + // exist. The call target must be an initialized TensorArray resource. A + // TensorArray can have multiple named gradients; see the operator + // documentation for TensorArrayGradV3 for details. + Status GetOrCreateTensorArrayGradient(const string& source, + xla::ComputationBuilder* builder, + XlaResource** gradient_out); + + // Packs a resource into a single XLA value `pack`, suitable for use as + // an XlaCompiler::Argument. For non-TensorArrays or TensorArrays without + // gradients, sets `*pack` to `value`. + // For TensorArrays with gradients, packs the value and its gradient values in + // a tuple; the gradients values are packed in order by source name. + Status Pack(xla::ComputationDataHandle* pack, + xla::ComputationBuilder* builder) const; + + // Returns the shape of the `pack` value computed by `Pack()`. + Status PackedShape(xla::ComputationBuilder* builder, + xla::Shape* packed_shape) const; + + // Updates the resource with values from `pack`. If `gradient_sources` is + // non-empty, treats `pack` as a tuple that represents a TensorArray and + // its gradients, and unpacks and updates the gradient resources. Opposite + // of Pack(). + Status SetFromPack(const std::set<string>& gradient_sources, + const xla::ComputationDataHandle& pack, + xla::ComputationBuilder* builder); }; // A XlaExpression wraps an XLA computation. Each Tensor on an diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 08b9faad4a..34b1246be2 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -60,8 +60,10 @@ Status CheckSignature(const DataTypeVector& types, bool XlaCompiler::Argument::operator==( const XlaCompiler::Argument& other) const { - if (std::tie(kind, type, name, tensor_array_size) != - std::tie(other.kind, other.type, other.name, other.tensor_array_size)) { + if (std::tie(kind, resource_kind, type, name, tensor_array_size, + tensor_array_gradients) != + std::tie(other.kind, other.resource_kind, other.type, other.name, + other.tensor_array_size, other.tensor_array_gradients)) { return false; } if (!xla::ShapeUtil::Equal(shape, other.shape)) { @@ -303,15 +305,27 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args, } // Fill in the handles in non-constant arguments. + VLOG(2) << "XLA computation inputs:"; for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) { const XlaCompiler::Argument& arg = args[parameters[i]]; + VLOG(2) << " XLA arg " << i + << " shape: " << xla::ShapeUtil::HumanString((*input_shapes)[i]) + << " name: " << arg.name << " TF arg " << parameters[i]; XlaExpression& arg_expression = (*arg_expressions)[parameters[i]]; switch (arg.kind) { - case XlaCompiler::Argument::kResource: + case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); - arg_expression.resource()->value = arg_handles[i]; - arg_expression.resource()->initial_value = arg_handles[i]; + XlaResource* resource = arg_expression.resource(); + TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients, + arg_handles[i], builder)); + VLOG(2) << " resource: num_gradients: " + << arg.tensor_array_gradients.size(); + resource->initial_value = resource->value; + for (const auto& gradient : resource->tensor_array_gradients) { + gradient.second->initial_value = gradient.second->value; + } break; + } case XlaCompiler::Argument::kParameter: arg_expression.set_handle(arg_handles[i]); break; @@ -341,6 +355,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args, // index of a resource variable argument to the computation, and `type` is the // type of the final output. Status BuildComputation( + const std::vector<XlaCompiler::Argument>& args, const std::vector<XlaExpression>& retvals, const std::vector<std::unique_ptr<XlaResource>>& resources, bool has_side_effects, bool return_updated_values_for_all_resources, @@ -357,27 +372,42 @@ Status BuildComputation( *num_nonconst_outputs = elems.size(); // Add return values for resources whose values have changed. - std::vector<const XlaResource*> arg_vars; - arg_vars.reserve(resources.size()); - for (const auto& var : resources) { - if (var->arg_num >= 0) { - arg_vars.push_back(var.get()); + std::vector<const XlaResource*> arg_resources; + arg_resources.reserve(resources.size()); + for (const auto& resource : resources) { + if (resource->arg_num >= 0) { + arg_resources.push_back(resource.get()); } } - std::sort(arg_vars.begin(), arg_vars.end(), + std::sort(arg_resources.begin(), arg_resources.end(), [](const XlaResource* a, const XlaResource* b) { return a->arg_num < b->arg_num; }); - for (const XlaResource* var : arg_vars) { - bool modified = var->value.handle() != var->initial_value.handle(); + for (const XlaResource* resource : arg_resources) { + const XlaCompiler::Argument& arg = args[resource->arg_num]; + bool modified = + resource->value.handle() != resource->initial_value.handle(); + // TensorArray gradients were modified if their values changed or there are + // any newly created gradients. + for (const auto& grad : resource->tensor_array_gradients) { + modified = + modified || + grad.second->value.handle() != grad.second->initial_value.handle() || + arg.tensor_array_gradients.count(grad.first) == 0; + } if (return_updated_values_for_all_resources || modified) { resource_updates->emplace_back(); XlaCompiler::ResourceUpdate& update = resource_updates->back(); - update.input_index = var->arg_num; - update.type = var->type; + update.input_index = resource->arg_num; + update.type = resource->type; update.modified = modified; - elems.push_back(var->value); + for (const auto& grad : resource->tensor_array_gradients) { + update.tensor_array_gradients_accessed.insert(grad.first); + } + xla::ComputationDataHandle handle; + TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); + elems.push_back(handle); } } @@ -453,7 +483,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_computation_outputs; result->computation = std::make_shared<xla::Computation>(); TF_RETURN_IF_ERROR(BuildComputation( - context->retvals(), context->resources(), context->has_side_effects(), + args, context->retvals(), context->resources(), + context->has_side_effects(), options.return_updated_values_for_all_resources, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->resource_updates)); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 809f668dd2..cf78e2cc13 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -44,14 +44,14 @@ namespace tensorflow { // // The XlaCompiler requires one Argument struct for each _Arg index, that // describes each argument. Arguments can be compile-time constants -// (kind kConstant), run-time parameters (kind kParameter), or resource -// variables (kinds kVariable and kUninitializedVariable). +// (kind kConstant), run-time parameters (kind kParameter), or resources +// (kind kResource). // -// Only kParameter and kVariable arguments become runtime parameters to the -// generated XLA computation. The XLA computation will have run-time parameters -// in the following order: +// Only kParameter and initialized kResource arguments become runtime parameters +// to the generated XLA computation. The XLA computation will have run-time +// parameters in the following order: // +---------------------+-----------------------------------------+ -// | kParameter values | Initial values of kVariable arguments | +// | kParameter values | Initial values of kResource arguments | // +---------------------+-----------------------------------------+ // Within each block, the arguments are arranged by the _Arg index from which // they were derived. @@ -61,18 +61,26 @@ namespace tensorflow { // The run-time outputs of the XLA computation are arranged in the following // order: // +------------------+-----------------------------------------+ -// | _Retval values | Updated values of kVariable arguments | +// | _Retval values | Updated values of kResource arguments | // +------------------+-----------------------------------------+ -// _Retval values are ordered by _Retval index, whereas kVariable values are +// _Retval values are ordered by _Retval index, whereas kResource values are // ordered by the original _Arg position of the variable. // -// In both inputs and outputs, kVariable values are placed the end. When +// In both inputs and outputs, kResource values are placed the end. When // emitting While loop bodies, we must ensure that the loop body has // identical input and output signatures. By moving variable values // to the end of the argument list and using the // `return_updated_values_for_all_variables` option, we can ensure that the -// input and output values of variables appear at the same positions. - +// input and output values of resources appear at the same positions. +// +// Resources are passed as parameters or returned as resource updates in +// "packed" form. +// kStack resources are packed as (array, size of stack) XLA tuples. +// kTensorArray resources without gradients are packed as the array that +// backs the TensorArray. If gradients are present (`tensor_array_gradients`), +// the packed representation is a (array, gradient0, gradient1, ...) tuple, +// where gradient_k is the value of the k-th gradient in the +// `tensor_array_gradients` ordered set. class XlaCompiler { public: // Describes how to derive the value of each _Arg node in the graph/function @@ -120,6 +128,11 @@ class XlaCompiler { // (Used for lazy initialization.) int64 tensor_array_size = -1; + // TensorArray resource parameters are passed as (array, gradient array 0, + // ..., gradient array k), where the gradient arrays are in the same order + // as `tensor_array_gradients`. + std::set<string> tensor_array_gradients; + bool operator==(const Argument& other) const; }; @@ -146,6 +159,9 @@ class XlaCompiler { // Was the value of the variable modified by the computation? // (Always true, unless `return_updated_values_for_all_resources` is true.) bool modified; + + // If the resource is a TensorArray, the set of gradients read or written. + std::set<string> tensor_array_gradients_accessed; }; struct CompilationResult { diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index aa8df80d34..f516dd867a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -349,5 +350,145 @@ TEST_F(XlaCompilerTest, ResourceManager) { resource->Unref(); } +// Tests a computation that receives a TensorArray resource as input and +// updates it. +TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); + auto flow = ops::Const<float>(scope, {}); + auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1"); + auto grad2 = ops::TensorArrayGrad(scope, arg, grad1.flow_out, "grad2"); + auto index = ops::Const<int32>(scope, 1); + auto write = ops::TensorArrayWrite(scope, grad1.grad_handle, index, index, + grad2.flow_out); + auto read = ops::TensorArrayRead(scope, arg, index, write.flow_out, DT_INT32); + auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector<XlaCompiler::Argument> args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kTensorArray; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2}), + xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].tensor_array_size = 2; + args[0].tensor_array_gradients = {"grad2"}; + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + ASSERT_EQ(1, result.resource_updates.size()); + const XlaCompiler::ResourceUpdate& update = result.resource_updates[0]; + EXPECT_EQ(0, update.input_index); + EXPECT_EQ(DT_INT32, update.type); + EXPECT_EQ((std::set<string>{"grad1", "grad2"}), + update.tensor_array_gradients_accessed); + + // Tests that the generated computation works. + std::unique_ptr<xla::Literal> input_base = + xla::Literal::CreateR1<int32>({7, 42}); + std::unique_ptr<xla::Literal> input_grad2 = + xla::Literal::CreateR1<int32>({-3, 101}); + std::unique_ptr<xla::Literal> input = + xla::Literal::MakeTuple({input_base.get(), input_grad2.get()}); + std::unique_ptr<xla::GlobalData> param0_data = + client_->TransferToServer(*input).ConsumeValueOrDie(); + + std::unique_ptr<xla::GlobalData> actual = + client_->Execute(*result.computation, {param0_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr<xla::Literal> actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr<xla::Literal> output_read = xla::Literal::CreateR0<int32>(42); + std::unique_ptr<xla::Literal> output_base = + xla::Literal::CreateR1<int32>({7, 42}); + std::unique_ptr<xla::Literal> output_grad1 = + xla::Literal::CreateR1<int32>({0, 1}); + std::unique_ptr<xla::Literal> output_grad2 = + xla::Literal::CreateR1<int32>({-3, 101}); + std::unique_ptr<xla::Literal> output_resource = xla::Literal::MakeTuple( + {output_base.get(), output_grad1.get(), output_grad2.get()}); + std::unique_ptr<xla::Literal> expected_literal = + xla::Literal::MakeTuple({output_read.get(), output_resource.get()}); + xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); +} + +// Tests compilation and execution of a graph that adds two tensors. +TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); + auto flow = ops::Const<float>(scope, {}); + auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1"); + auto index = ops::Const<int32>(scope, 1); + auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32); + auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector<XlaCompiler::Argument> args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kTensorArray; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2}), + xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].tensor_array_size = 2; + args[0].tensor_array_gradients = {"grad1"}; + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + EXPECT_EQ(0, result.resource_updates.size()); +} + +// Tests compilation and execution of a graph that adds two tensors. +TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0); + auto flow = ops::Const<float>(scope, {}); + auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad2"); + auto index = ops::Const<int32>(scope, 1); + auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32); + auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector<XlaCompiler::Argument> args(1); + args[0].kind = XlaCompiler::Argument::kResource; + args[0].resource_kind = XlaResource::kTensorArray; + args[0].initialized = true; + args[0].type = DT_INT32; + args[0].shape = xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::S32, {2}), + xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].tensor_array_size = 2; + args[0].tensor_array_gradients = {"grad1"}; + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + EXPECT_EQ(1, result.resource_updates.size()); +} + } // namespace } // namespace tensorflow |