aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tf2xla
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-09-19 19:08:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-19 19:12:43 -0700
commit1f20a786d69c4b91a4015fe3f4df8c23bd345f40 (patch)
tree9175a24a490a21587cb899b4dda98f11f83f948c /tensorflow/compiler/tf2xla
parent5ce3523bcc844217b47e7f862c1bed894cbaa34e (diff)
[TF:XLA] Add support for reading and writing TensorArray gradients in a while loop.
Previously, there was no code to handle propagating the values of a TensorArray's gradients into and out of loops. This change passes TensorArray gradients into and out of loops by packing them up as a (base array, gradient values...) tuple. PiperOrigin-RevId: 169338418
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r--tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc24
-rw-r--r--tensorflow/compiler/tf2xla/kernels/while_op.cc97
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.cc100
-rw-r--r--tensorflow/compiler/tf2xla/xla_compilation_device.h43
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc65
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h38
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc141
7 files changed, 428 insertions, 80 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
index 7f1597e9ad..c42d8b97ea 100644
--- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -114,12 +115,7 @@ Status CheckTensorArrayIsInitialized(const string& op_name,
Status GetTensorArrayShape(const XlaResource* resource,
xla::ComputationBuilder* builder,
TensorShape* shape) {
- auto shape_or_status = builder->GetShape(resource->value);
- if (!shape_or_status.ok()) {
- return shape_or_status.status();
- }
- TF_RETURN_IF_ERROR(
- XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape));
+ TF_RETURN_IF_ERROR(resource->GetShape(builder, shape));
if (shape->dims() < 1) {
return errors::InvalidArgument("TensorArray rank must be >= 1");
}
@@ -532,19 +528,9 @@ class TensorArrayGradOp : public XlaOpKernel {
// Finds or looks up the corresponding gradient TensorArray, which stores
// gradients computed during backpropagation.
- XlaResource*& gradient = resource->tensor_array_gradient[source_];
- if (!gradient) {
- xla::ComputationDataHandle zero = XlaHelpers::Zero(b, resource->type);
- xla::ComputationDataHandle value =
- b->Broadcast(zero, ta_shape.dim_sizes());
-
- XlaContext& xc = XlaContext::Get(ctx);
- string name = strings::StrCat("TensorArrayGrad: ", resource->name);
- OP_REQUIRES_OK(
- ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name),
- resource->type, value, &gradient));
- gradient->tensor_array_size = resource->tensor_array_size;
- }
+ XlaResource* gradient;
+ OP_REQUIRES_OK(
+ ctx, resource->GetOrCreateTensorArrayGradient(source_, b, &gradient));
ctx->SetResourceOutput(0, gradient);
ctx->SetConstantOutput(1, Tensor(DT_FLOAT));
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index 55995aa421..ead26478ff 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -33,10 +33,11 @@ namespace {
// Builds XlaCompiler argument descriptions `args` from `ctx`.
Status MakeXlaCompilerArgumentsFromInputs(
XlaOpKernelContext* ctx, std::vector<XlaCompiler::Argument>* args,
- bool* has_uninitialized_vars) {
+ bool* has_uninitialized_vars, bool* has_tensor_arrays) {
VLOG(2) << "Num inputs " << ctx->num_inputs();
args->resize(ctx->num_inputs());
*has_uninitialized_vars = false;
+ *has_tensor_arrays = false;
for (int i = 0; i < ctx->num_inputs(); ++i) {
VLOG(2) << " Input " << i
<< " type: " << DataTypeString(ctx->input_type(i))
@@ -52,20 +53,24 @@ Status MakeXlaCompilerArgumentsFromInputs(
arg.initialized = resource->value.handle() > 0;
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = resource->kind;
+ if (arg.resource_kind == XlaResource::kTensorArray) {
+ *has_tensor_arrays = true;
+ }
+
arg.type = resource->type;
if (arg.initialized) {
- auto shape = ctx->builder()->GetShape(resource->value);
- TF_RETURN_IF_ERROR(shape.status());
- arg.shape = *shape.ValueOrDie();
+ TF_RETURN_IF_ERROR(resource->PackedShape(ctx->builder(), &arg.shape));
} else {
*has_uninitialized_vars = true;
}
arg.tensor_array_size = resource->tensor_array_size;
+ for (const auto& gradient : resource->tensor_array_gradients) {
+ arg.tensor_array_gradients.insert(gradient.first);
+ }
arg.name = resource->name;
- // TODO(phawkins): propagate TensorArray gradients into loops.
VLOG(2) << " resource " << resource->name
<< " type: " << DataTypeString(arg.type)
- << " shape: " << arg.shape.DebugString()
+ << " shape: " << xla::ShapeUtil::HumanString(arg.shape)
<< " initialized: " << arg.initialized;
} else {
@@ -93,8 +98,10 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
std::vector<XlaCompiler::Argument> arguments;
bool has_uninitialized_vars;
- OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs(
- ctx, &arguments, &has_uninitialized_vars));
+ bool has_tensor_arrays;
+ OP_REQUIRES_OK(
+ ctx, MakeXlaCompilerArgumentsFromInputs(
+ ctx, &arguments, &has_uninitialized_vars, &has_tensor_arrays));
xla::ComputationBuilder* builder = ctx->builder();
XlaCompiler* compiler = ctx->compiler();
@@ -118,38 +125,67 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
arguments, &body));
// We must use a static shape for parameters to an XLA compilation. However,
- // we may not know the shape of a TensorArray if it is first written inside
- // the loop. Ideally we would require the user to provide a static shape,
- // but this is not always easy.
- // So if uninitialized resource are used by the loop body, we compile the
- // body function twice:
- // 1) once with uninitialized resource inputs. We discard the computation
- // but we assume resource shapes reach a fixpoint after one iteration.
- // So we can use the output shapes of the resource as the "true" shapes.
- // 2) again with the "correct" input shapes determined by (1).
- if (has_uninitialized_vars) {
+ // we may not know the shape of a resource if it is first
+ // written inside the loop. Furthermore, we do not know ahead of time which
+ // gradient TensorArrays will be created by the TensorArrayGradV3 operator.
+ //
+  // Ideally we would change TensorFlow to always provide static shapes, but
+  // this is not easy to do. So if uninitialized resources or TensorArrays
+ // are used by the loop body, we compile the body function twice:
+ // 1) once with uninitialized resource inputs and no TensorArray gradient
+ // inputs. We then discard the computation but we assume resource shapes
+ // and the set of gradients read or written will reach a fixpoint after one
+ // iteration.
+ // Hence we can use the output shapes and TensorArray gradients of each
+ // resource as the "true" shapes.
+ // 2) again with the "correct" resource information determined by (1).
+ if (has_uninitialized_vars || has_tensor_arrays) {
+ VLOG(2) << "Recompiling loop body: has_uninitialized_vars: "
+ << has_uninitialized_vars
+ << " has_tensor_arrays: " << has_tensor_arrays;
// Initializes any uninitialized resource with zero values of the
// shape determined by the first compilation.
for (int i = 0; i < body.resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& update = body.resource_updates[i];
+ XlaResource* resource;
+ OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
+
XlaCompiler::Argument& arg = arguments[update.input_index];
if (!arg.initialized) {
VLOG(2) << "Update shape for argument " << update.input_index << " "
<< xla::ShapeUtil::HumanString(update.shape);
arg.initialized = true;
- arg.shape = update.shape;
-
- XlaResource* resource;
- OP_REQUIRES_OK(ctx,
- ctx->GetResourceInput(update.input_index, &resource));
+ xla::Shape shape = update.shape;
+ if (!update.tensor_array_gradients_accessed.empty()) {
+ shape = xla::ShapeUtil::GetTupleElementShape(shape, 0);
+ }
std::unique_ptr<xla::Literal> zero =
- xla::Literal::CreateFromShape(update.shape);
+ xla::Literal::CreateFromShape(shape);
resource->value = builder->ConstantLiteral(*zero);
}
+
+ // Add any TensorArray gradients touched by the body to the enclosing
+ // graph.
+ for (const string& grad_source : update.tensor_array_gradients_accessed) {
+ VLOG(4) << "TensorArray " << resource->name << " accessed gradient "
+ << grad_source;
+ XlaResource* gradient;
+ OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient(
+ grad_source, builder, &gradient));
+ }
+
+ // Add all of the TensorArray gradients to the argument. For simplicity,
+ // we always pass all known gradients.
+ for (const auto& gradient : resource->tensor_array_gradients) {
+ arg.tensor_array_gradients.insert(gradient.first);
+ }
+
+ // Recompute the argument shape.
+ OP_REQUIRES_OK(ctx, resource->PackedShape(ctx->builder(), &arg.shape));
}
- // Recompile the body with the "correct" shapes.
- VLOG(1) << "Recompiling body with non-placeholder shapes";
+ // Recompile the body with the "correct" resource shapes.
+ VLOG(1) << "Recompiling body with corrected resource shapes";
body = {};
OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_,
arguments, &body));
@@ -203,7 +239,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
if (ctx->input_type(input_num) == DT_RESOURCE) {
XlaResource* resource;
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
- inputs[i] = resource->value;
+ OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], builder));
} else {
inputs[i] = ctx->Input(i);
}
@@ -244,12 +280,15 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) {
OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource));
if (update.modified) {
int pos = body.outputs.size() + i;
- resource->value = builder->GetTupleElement(while_result, pos);
+ OP_REQUIRES_OK(ctx,
+ resource->SetFromPack(
+ arguments[update.input_index].tensor_array_gradients,
+ builder->GetTupleElement(while_result, pos), builder));
}
VLOG(2) << "Loop-carried variable: pos: " << update.input_index
<< " name: " << resource->name << " modified: " << update.modified
<< " type: " << DataTypeString(update.type)
- << " shape: " << update.shape.DebugString();
+ << " shape: " << xla::ShapeUtil::HumanString(update.shape);
// Copies the identity of the resource variable from input to output
// unchanged, even if the variable was not modified.
ctx->op_kernel_context()->set_output(
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index 1d0098591e..4e6ef489f6 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -18,7 +18,9 @@ limitations under the License.
#include <functional>
#include <memory>
+#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/platform/mem.h"
@@ -87,7 +89,7 @@ Allocator* XlaCompilationDevice::GetAllocator(AllocatorAttributes attr) {
void XlaCompilationDevice::Compute(OpKernel* op_kernel,
OpKernelContext* context) {
- VLOG(1) << "XlaCompilationDevice::Compute "
+ VLOG(4) << "XlaCompilationDevice::Compute "
<< SummarizeNodeDef(op_kernel->def());
auto* b = XlaContext::Get(context).builder();
xla::OpMetadata metadata;
@@ -96,7 +98,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel,
b->SetOpMetadata(metadata);
op_kernel->Compute(context);
b->ClearOpMetadata();
- VLOG(2) << "Done";
+ VLOG(4) << "Done";
}
Status XlaCompilationDevice::Sync() { return Status::OK(); }
@@ -119,4 +121,98 @@ void XlaExpression::set_constant_value(Tensor value) {
constant_value_ = std::move(value);
}
+Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder,
+ xla::Shape* shape) const {
+ auto shape_or_status = builder->GetShape(value);
+ if (!shape_or_status.ok()) {
+ return shape_or_status.status();
+ }
+ *shape = *shape_or_status.ValueOrDie();
+ return Status::OK();
+}
+
+Status XlaResource::GetShape(xla::ComputationBuilder* builder,
+ TensorShape* shape) const {
+ xla::Shape xla_shape;
+ TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape));
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape));
+ return Status::OK();
+}
+
+Status XlaResource::GetOrCreateTensorArrayGradient(
+ const string& source, xla::ComputationBuilder* builder,
+ XlaResource** gradient_out) {
+ VLOG(2) << "Gradient lookup for resource: " << name
+ << " gradient: " << source;
+ TF_RET_CHECK(kind == kTensorArray);
+ std::unique_ptr<XlaResource>& gradient = tensor_array_gradients[source];
+ if (!gradient) {
+ gradient.reset(new XlaResource);
+ gradient->kind = XlaResource::kTensorArray;
+ gradient->name = strings::StrCat("TensorArrayGrad: ", name);
+ gradient->type = type;
+ gradient->tensor_array_size = tensor_array_size;
+
+ TensorShape ta_shape;
+ TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape));
+ gradient->value = builder->Broadcast(XlaHelpers::Zero(builder, type),
+ ta_shape.dim_sizes());
+ gradient->initial_value = gradient->value;
+ }
+ *gradient_out = gradient.get();
+ return Status::OK();
+}
+
+Status XlaResource::PackedShape(xla::ComputationBuilder* builder,
+ xla::Shape* packed_shape) const {
+ if (tensor_array_gradients.empty()) {
+ return GetXlaShape(builder, packed_shape);
+ }
+ TF_RET_CHECK(kind == kTensorArray);
+ std::vector<xla::Shape> elem_shapes(1 + tensor_array_gradients.size());
+ int pos = 0;
+ TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++]));
+ for (const auto& gradient : tensor_array_gradients) {
+ TF_RETURN_IF_ERROR(
+ gradient.second->GetXlaShape(builder, &elem_shapes[pos++]));
+ }
+ *packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes);
+ return Status::OK();
+}
+
+Status XlaResource::Pack(xla::ComputationDataHandle* pack,
+ xla::ComputationBuilder* builder) const {
+ if (tensor_array_gradients.empty()) {
+ *pack = value;
+ } else {
+ TF_RET_CHECK(kind == kTensorArray);
+ std::vector<xla::ComputationDataHandle> elems;
+ elems.push_back(value);
+ for (const auto& gradient : tensor_array_gradients) {
+ elems.push_back(gradient.second->value);
+ }
+ *pack = builder->Tuple(elems);
+ }
+ return Status::OK();
+}
+
+Status XlaResource::SetFromPack(const std::set<string>& gradient_sources,
+ const xla::ComputationDataHandle& pack,
+ xla::ComputationBuilder* builder) {
+ if (gradient_sources.empty()) {
+ value = pack;
+ } else {
+ TF_RET_CHECK(kind == kTensorArray);
+ int pos = 0;
+ value = builder->GetTupleElement(pack, pos++);
+ for (const auto& source : gradient_sources) {
+ XlaResource* gradient;
+ TF_RETURN_IF_ERROR(
+ GetOrCreateTensorArrayGradient(source, builder, &gradient));
+ gradient->value = builder->GetTupleElement(pack, pos++);
+ }
+ }
+ return Status::OK();
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h
index 22c24f4963..765683cf1d 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.h
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
@@ -65,6 +66,7 @@ class XlaCompilationDevice : public LocalDevice {
};
// Represents a resource, such as a Variable or TensorArray.
+// TODO(phawkins): make this into a properly abstracted class.
struct XlaResource {
enum Kind {
kInvalid,
@@ -103,8 +105,45 @@ struct XlaResource {
// 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes
// to an XlaResource containing the gradient TensorArrays. We store a pointer
// here since there should only be one gradient TensorArray per 'source'
- // string, irrespective of the number of calls to TensorArrayGrad.
- std::unordered_map<string, XlaResource*> tensor_array_gradient;
+ // string, irrespective of the number of calls to TensorArrayGrad. The map
+ // is ordered since values are packed into tuples by Pack() sorted by name
+ // order.
+ std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients;
+
+ // Returns the shape of the resource as an xla::Shape.
+ Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const;
+
+  // Returns the shape of the resource as a TensorShape. Fails if the shape is
+ // not representable as a TensorShape.
+ Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const;
+
+ // Looks up the gradient for `source`, or creates it if it does not already
+ // exist. The call target must be an initialized TensorArray resource. A
+ // TensorArray can have multiple named gradients; see the operator
+ // documentation for TensorArrayGradV3 for details.
+ Status GetOrCreateTensorArrayGradient(const string& source,
+ xla::ComputationBuilder* builder,
+ XlaResource** gradient_out);
+
+ // Packs a resource into a single XLA value `pack`, suitable for use as
+ // an XlaCompiler::Argument. For non-TensorArrays or TensorArrays without
+ // gradients, sets `*pack` to `value`.
+ // For TensorArrays with gradients, packs the value and its gradient values in
+  // a tuple; the gradient values are packed in order by source name.
+ Status Pack(xla::ComputationDataHandle* pack,
+ xla::ComputationBuilder* builder) const;
+
+ // Returns the shape of the `pack` value computed by `Pack()`.
+ Status PackedShape(xla::ComputationBuilder* builder,
+ xla::Shape* packed_shape) const;
+
+ // Updates the resource with values from `pack`. If `gradient_sources` is
+ // non-empty, treats `pack` as a tuple that represents a TensorArray and
+ // its gradients, and unpacks and updates the gradient resources. Opposite
+ // of Pack().
+ Status SetFromPack(const std::set<string>& gradient_sources,
+ const xla::ComputationDataHandle& pack,
+ xla::ComputationBuilder* builder);
};
// A XlaExpression wraps an XLA computation. Each Tensor on an
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 08b9faad4a..34b1246be2 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -60,8 +60,10 @@ Status CheckSignature(const DataTypeVector& types,
bool XlaCompiler::Argument::operator==(
const XlaCompiler::Argument& other) const {
- if (std::tie(kind, type, name, tensor_array_size) !=
- std::tie(other.kind, other.type, other.name, other.tensor_array_size)) {
+ if (std::tie(kind, resource_kind, type, name, tensor_array_size,
+ tensor_array_gradients) !=
+ std::tie(other.kind, other.resource_kind, other.type, other.name,
+ other.tensor_array_size, other.tensor_array_gradients)) {
return false;
}
if (!xla::ShapeUtil::Equal(shape, other.shape)) {
@@ -303,15 +305,27 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
}
// Fill in the handles in non-constant arguments.
+ VLOG(2) << "XLA computation inputs:";
for (std::vector<int>::size_type i = 0; i < parameters.size(); ++i) {
const XlaCompiler::Argument& arg = args[parameters[i]];
+ VLOG(2) << " XLA arg " << i
+ << " shape: " << xla::ShapeUtil::HumanString((*input_shapes)[i])
+ << " name: " << arg.name << " TF arg " << parameters[i];
XlaExpression& arg_expression = (*arg_expressions)[parameters[i]];
switch (arg.kind) {
- case XlaCompiler::Argument::kResource:
+ case XlaCompiler::Argument::kResource: {
TF_RET_CHECK(arg.initialized);
- arg_expression.resource()->value = arg_handles[i];
- arg_expression.resource()->initial_value = arg_handles[i];
+ XlaResource* resource = arg_expression.resource();
+ TF_RETURN_IF_ERROR(resource->SetFromPack(arg.tensor_array_gradients,
+ arg_handles[i], builder));
+ VLOG(2) << " resource: num_gradients: "
+ << arg.tensor_array_gradients.size();
+ resource->initial_value = resource->value;
+ for (const auto& gradient : resource->tensor_array_gradients) {
+ gradient.second->initial_value = gradient.second->value;
+ }
break;
+ }
case XlaCompiler::Argument::kParameter:
arg_expression.set_handle(arg_handles[i]);
break;
@@ -341,6 +355,7 @@ Status BuildArguments(const std::vector<XlaCompiler::Argument>& args,
// index of a resource variable argument to the computation, and `type` is the
// type of the final output.
Status BuildComputation(
+ const std::vector<XlaCompiler::Argument>& args,
const std::vector<XlaExpression>& retvals,
const std::vector<std::unique_ptr<XlaResource>>& resources,
bool has_side_effects, bool return_updated_values_for_all_resources,
@@ -357,27 +372,42 @@ Status BuildComputation(
*num_nonconst_outputs = elems.size();
// Add return values for resources whose values have changed.
- std::vector<const XlaResource*> arg_vars;
- arg_vars.reserve(resources.size());
- for (const auto& var : resources) {
- if (var->arg_num >= 0) {
- arg_vars.push_back(var.get());
+ std::vector<const XlaResource*> arg_resources;
+ arg_resources.reserve(resources.size());
+ for (const auto& resource : resources) {
+ if (resource->arg_num >= 0) {
+ arg_resources.push_back(resource.get());
}
}
- std::sort(arg_vars.begin(), arg_vars.end(),
+ std::sort(arg_resources.begin(), arg_resources.end(),
[](const XlaResource* a, const XlaResource* b) {
return a->arg_num < b->arg_num;
});
- for (const XlaResource* var : arg_vars) {
- bool modified = var->value.handle() != var->initial_value.handle();
+ for (const XlaResource* resource : arg_resources) {
+ const XlaCompiler::Argument& arg = args[resource->arg_num];
+ bool modified =
+ resource->value.handle() != resource->initial_value.handle();
+ // TensorArray gradients were modified if their values changed or there are
+ // any newly created gradients.
+ for (const auto& grad : resource->tensor_array_gradients) {
+ modified =
+ modified ||
+ grad.second->value.handle() != grad.second->initial_value.handle() ||
+ arg.tensor_array_gradients.count(grad.first) == 0;
+ }
if (return_updated_values_for_all_resources || modified) {
resource_updates->emplace_back();
XlaCompiler::ResourceUpdate& update = resource_updates->back();
- update.input_index = var->arg_num;
- update.type = var->type;
+ update.input_index = resource->arg_num;
+ update.type = resource->type;
update.modified = modified;
- elems.push_back(var->value);
+ for (const auto& grad : resource->tensor_array_gradients) {
+ update.tensor_array_gradients_accessed.insert(grad.first);
+ }
+ xla::ComputationDataHandle handle;
+ TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
+ elems.push_back(handle);
}
}
@@ -453,7 +483,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
int num_computation_outputs;
result->computation = std::make_shared<xla::Computation>();
TF_RETURN_IF_ERROR(BuildComputation(
- context->retvals(), context->resources(), context->has_side_effects(),
+ args, context->retvals(), context->resources(),
+ context->has_side_effects(),
options.return_updated_values_for_all_resources, &builder,
result->computation.get(), &num_computation_outputs,
&num_nonconst_outputs, &result->resource_updates));
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 809f668dd2..cf78e2cc13 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -44,14 +44,14 @@ namespace tensorflow {
//
// The XlaCompiler requires one Argument struct for each _Arg index, that
// describes each argument. Arguments can be compile-time constants
-// (kind kConstant), run-time parameters (kind kParameter), or resource
-// variables (kinds kVariable and kUninitializedVariable).
+// (kind kConstant), run-time parameters (kind kParameter), or resources
+// (kind kResource).
//
-// Only kParameter and kVariable arguments become runtime parameters to the
-// generated XLA computation. The XLA computation will have run-time parameters
-// in the following order:
+// Only kParameter and initialized kResource arguments become runtime parameters
+// to the generated XLA computation. The XLA computation will have run-time
+// parameters in the following order:
// +---------------------+-----------------------------------------+
-// | kParameter values | Initial values of kVariable arguments |
+// | kParameter values | Initial values of kResource arguments |
// +---------------------+-----------------------------------------+
// Within each block, the arguments are arranged by the _Arg index from which
// they were derived.
@@ -61,18 +61,26 @@ namespace tensorflow {
// The run-time outputs of the XLA computation are arranged in the following
// order:
// +------------------+-----------------------------------------+
-// | _Retval values | Updated values of kVariable arguments |
+// | _Retval values | Updated values of kResource arguments |
// +------------------+-----------------------------------------+
-// _Retval values are ordered by _Retval index, whereas kVariable values are
+// _Retval values are ordered by _Retval index, whereas kResource values are
// ordered by the original _Arg position of the variable.
//
-// In both inputs and outputs, kVariable values are placed the end. When
+// In both inputs and outputs, kResource values are placed at the end. When
// emitting While loop bodies, we must ensure that the loop body has
// identical input and output signatures. By moving variable values
// to the end of the argument list and using the
// `return_updated_values_for_all_variables` option, we can ensure that the
-// input and output values of variables appear at the same positions.
-
+// input and output values of resources appear at the same positions.
+//
+// Resources are passed as parameters or returned as resource updates in
+// "packed" form.
+// kStack resources are packed as (array, size of stack) XLA tuples.
+// kTensorArray resources without gradients are packed as the array that
+// backs the TensorArray. If gradients are present (`tensor_array_gradients`),
+// the packed representation is an (array, gradient0, gradient1, ...) tuple,
+// where gradient_k is the value of the k-th gradient in the
+// `tensor_array_gradients` ordered set.
class XlaCompiler {
public:
// Describes how to derive the value of each _Arg node in the graph/function
@@ -120,6 +128,11 @@ class XlaCompiler {
// (Used for lazy initialization.)
int64 tensor_array_size = -1;
+ // TensorArray resource parameters are passed as (array, gradient array 0,
+ // ..., gradient array k), where the gradient arrays are in the same order
+ // as `tensor_array_gradients`.
+ std::set<string> tensor_array_gradients;
+
bool operator==(const Argument& other) const;
};
@@ -146,6 +159,9 @@ class XlaCompiler {
// Was the value of the variable modified by the computation?
// (Always true, unless `return_updated_values_for_all_resources` is true.)
bool modified;
+
+ // If the resource is a TensorArray, the set of gradients read or written.
+ std::set<string> tensor_array_gradients_accessed;
};
struct CompilationResult {
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index aa8df80d34..f516dd867a 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
@@ -349,5 +350,145 @@ TEST_F(XlaCompilerTest, ResourceManager) {
resource->Unref();
}
+// Tests a computation that receives a TensorArray resource as input and
+// updates it.
+TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
+ auto flow = ops::Const<float>(scope, {});
+ auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
+ auto grad2 = ops::TensorArrayGrad(scope, arg, grad1.flow_out, "grad2");
+ auto index = ops::Const<int32>(scope, 1);
+ auto write = ops::TensorArrayWrite(scope, grad1.grad_handle, index, index,
+ grad2.flow_out);
+ auto read = ops::TensorArrayRead(scope, arg, index, write.flow_out, DT_INT32);
+ auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(1);
+ args[0].kind = XlaCompiler::Argument::kResource;
+ args[0].resource_kind = XlaResource::kTensorArray;
+ args[0].initialized = true;
+ args[0].type = DT_INT32;
+ args[0].shape = xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::S32, {2}),
+ xla::ShapeUtil::MakeShape(xla::S32, {2})});
+ args[0].tensor_array_size = 2;
+ args[0].tensor_array_gradients = {"grad2"};
+
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+ std::move(graph), args, &result));
+
+ ASSERT_EQ(1, result.resource_updates.size());
+ const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
+ EXPECT_EQ(0, update.input_index);
+ EXPECT_EQ(DT_INT32, update.type);
+ EXPECT_EQ((std::set<string>{"grad1", "grad2"}),
+ update.tensor_array_gradients_accessed);
+
+ // Tests that the generated computation works.
+ std::unique_ptr<xla::Literal> input_base =
+ xla::Literal::CreateR1<int32>({7, 42});
+ std::unique_ptr<xla::Literal> input_grad2 =
+ xla::Literal::CreateR1<int32>({-3, 101});
+ std::unique_ptr<xla::Literal> input =
+ xla::Literal::MakeTuple({input_base.get(), input_grad2.get()});
+ std::unique_ptr<xla::GlobalData> param0_data =
+ client_->TransferToServer(*input).ConsumeValueOrDie();
+
+ std::unique_ptr<xla::GlobalData> actual =
+ client_->Execute(*result.computation, {param0_data.get()})
+ .ConsumeValueOrDie();
+ std::unique_ptr<xla::Literal> actual_literal =
+ client_->Transfer(*actual).ConsumeValueOrDie();
+
+ std::unique_ptr<xla::Literal> output_read = xla::Literal::CreateR0<int32>(42);
+ std::unique_ptr<xla::Literal> output_base =
+ xla::Literal::CreateR1<int32>({7, 42});
+ std::unique_ptr<xla::Literal> output_grad1 =
+ xla::Literal::CreateR1<int32>({0, 1});
+ std::unique_ptr<xla::Literal> output_grad2 =
+ xla::Literal::CreateR1<int32>({-3, 101});
+ std::unique_ptr<xla::Literal> output_resource = xla::Literal::MakeTuple(
+ {output_base.get(), output_grad1.get(), output_grad2.get()});
+ std::unique_ptr<xla::Literal> expected_literal =
+ xla::Literal::MakeTuple({output_read.get(), output_resource.get()});
+ xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+}
+
+// Tests that TensorArray gradients that are read but never written are not
+// returned as resource updates of the computation.
+TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
+ auto flow = ops::Const<float>(scope, {});
+ auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
+ auto index = ops::Const<int32>(scope, 1);
+ auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
+ auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(1);
+ args[0].kind = XlaCompiler::Argument::kResource;
+ args[0].resource_kind = XlaResource::kTensorArray;
+ args[0].initialized = true;
+ args[0].type = DT_INT32;
+ args[0].shape = xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::S32, {2}),
+ xla::ShapeUtil::MakeShape(xla::S32, {2})});
+ args[0].tensor_array_size = 2;
+ args[0].tensor_array_gradients = {"grad1"};
+
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+ std::move(graph), args, &result));
+
+ EXPECT_EQ(0, result.resource_updates.size());
+}
+
+// Tests that TensorArray gradients newly created inside the computation are
+// returned as resource updates of the computation.
+TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
+ auto flow = ops::Const<float>(scope, {});
+ auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad2");
+ auto index = ops::Const<int32>(scope, 1);
+ auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
+ auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(1);
+ args[0].kind = XlaCompiler::Argument::kResource;
+ args[0].resource_kind = XlaResource::kTensorArray;
+ args[0].initialized = true;
+ args[0].type = DT_INT32;
+ args[0].shape = xla::ShapeUtil::MakeTupleShape(
+ {xla::ShapeUtil::MakeShape(xla::S32, {2}),
+ xla::ShapeUtil::MakeShape(xla::S32, {2})});
+ args[0].tensor_array_size = 2;
+ args[0].tensor_array_gradients = {"grad1"};
+
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+ std::move(graph), args, &result));
+
+ EXPECT_EQ(1, result.resource_updates.size());
+}
+
} // namespace
} // namespace tensorflow