aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-02-14 14:49:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-14 14:53:40 -0800
commitaa92995493939b674cd54b13b5850cc185a6e7ae (patch)
tree5ce3c336a1ca4d1241af5060de7869a64ba48c30 /tensorflow/compiler
parent90159c53b83527bd088452769ea4e1b98667860c (diff)
[TF:XLA] Add a hook to allow reshaping of TensorFlow variables when storing them in their XLA representation.
PiperOrigin-RevId: 185748660
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/tf2xla/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.cc24
-rw-r--r--tensorflow/compiler/tf2xla/literal_util.h9
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc243
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h22
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc124
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.cc16
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.h14
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc22
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.h2
10 files changed, 341 insertions, 136 deletions
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 3c7dfef03d..fb82c2601c 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -312,6 +312,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:ops",
+ "//tensorflow/cc:resource_variable_ops",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc
index fcbd157c61..2c3cd658e0 100644
--- a/tensorflow/compiler/tf2xla/literal_util.cc
+++ b/tensorflow/compiler/tf2xla/literal_util.cc
@@ -40,20 +40,20 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) {
return Status::OK();
}
-Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
- Tensor* host_tensor) {
+Status CopyLiteralToHostTensor(const xla::Literal& literal,
+ Tensor* host_tensor) {
+ TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) &&
+ xla::ShapeUtil::ElementsIn(literal.shape()) ==
+ host_tensor->NumElements());
xla::PrimitiveType primitive_type;
- TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(target_type, &primitive_type));
+ TF_RETURN_IF_ERROR(
+ DataTypeToPrimitiveType(host_tensor->dtype(), &primitive_type));
if (literal.shape().element_type() != primitive_type) {
return errors::InvalidArgument(
"Cannot convert literal of type ",
xla::PrimitiveType_Name(literal.shape().element_type()),
- " to tensor of type ", DataTypeString(target_type));
+ " to tensor of type ", DataTypeString(host_tensor->dtype()));
}
-
- TensorShape shape;
- TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape));
- *host_tensor = Tensor(target_type, shape);
size_t total_bytes = host_tensor->TotalBytes();
if (total_bytes > 0) {
const void* src_ptr = literal.untyped_data();
@@ -63,4 +63,12 @@ Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
return Status::OK();
}
+Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
+ Tensor* host_tensor) {
+ TensorShape shape;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape));
+ *host_tensor = Tensor(target_type, shape);
+ return CopyLiteralToHostTensor(literal, host_tensor);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h
index fe08e83c23..f283b02368 100644
--- a/tensorflow/compiler/tf2xla/literal_util.h
+++ b/tensorflow/compiler/tf2xla/literal_util.h
@@ -29,7 +29,8 @@ namespace tensorflow {
// unsupported type.
Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal);
-// Copies 'literal' to 'host_tensor', which is allocated of type <target_type>.
+// Copies 'literal' to 'host_tensor', which is freshly allocated with type
+// <target_type> and the shape of 'literal'.
// Fails if the literal's primitive type !=
// DataTypeToPrimitiveType(target_type). Note that <target_type> is not
// derivable from the type of <literal>, because multiple tensorflow types map
@@ -38,6 +39,12 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal);
Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type,
Tensor* host_tensor);
+// Copies the contents of 'literal' to a previously allocated tensor
+// 'host_tensor'. The tensor and the literal must have the same number of
+// elements and the same type.
+Status CopyLiteralToHostTensor(const xla::Literal& literal,
+ Tensor* host_tensor);
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 59e8830442..15bba46ac6 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -109,6 +109,12 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
local_flib_runtime_ = local_pflr_->GetFLR(device_->name());
flib_runtime_ = pflr_->GetFLR(device_->name());
+
+ // The default variable representation shape is the identity function.
+ if (!options_.variable_representation_shape_fn) {
+ options_.variable_representation_shape_fn =
+ [](const TensorShape& shape, DataType type) { return shape; };
+ }
}
XlaCompiler::~XlaCompiler() = default;
@@ -223,8 +229,8 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
}
// Computes the XLA shape for argument 'arg'.
-/*static*/ Status XlaCompiler::XLAShapeForArgument(
- const XlaCompiler::Argument& arg, xla::Shape* xla_shape) {
+Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
+ xla::Shape* xla_shape) {
switch (arg.kind) {
case XlaCompiler::Argument::kConstant:
return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
@@ -235,8 +241,12 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
TF_RET_CHECK(arg.initialized);
switch (arg.resource_kind) {
- case XlaResource::kVariable:
- return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
+ case XlaResource::kVariable: {
+ TensorShape representation_shape =
+ options_.variable_representation_shape_fn(arg.shape, arg.type);
+ return TensorShapeToXLAShape(arg.type, representation_shape,
+ xla_shape);
+ }
case XlaResource::kTensorArray: {
if (arg.tensor_array_size < 0) {
return errors::InvalidArgument(
@@ -310,16 +320,116 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
return Status::OK();
}
+// Builds the XLA computation.
+//
+// `retvals` is the list of retvals produced by _Retval operators, in index
+// order. `variable_map` is a map from variable ID numbers to XlaOpContext
+// variable states, generated by the symbolic evaluation.
+// If `return_updated_values_for_all_resources` is true, all resources will be
+// included in `resource_updates`, regardless of whether their value changed.
+// Sets `*num_nonconst_outputs` to the number of non-constant retvals.
+// Sets `*resource_updates` to a description of resources whose values are
+// written by the computation; the variable writes are the last
+// `resource_updates.size()` return values from the computation. Each entry in
+// `resource_updates` is a (input_index, type) pair, where `input_index` is the
+// index of a resource variable argument to the computation, and `type` is the
+// type of the final output.
+Status BuildComputation(
+ const std::vector<XlaCompiler::Argument>& args,
+ const std::vector<int>& arg_cores,
+ const std::vector<XlaExpression>& retvals,
+ const std::vector<std::unique_ptr<XlaResource>>& resources,
+ bool return_updated_values_for_all_resources,
+ xla::ComputationBuilder* builder, xla::Computation* computation,
+ int* num_computation_outputs, int* num_nonconst_outputs,
+ std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
+ std::vector<xla::ComputationDataHandle> elems;
+ elems.reserve(retvals.size());
+ for (const XlaExpression& retval : retvals) {
+ if (!retval.has_constant_value()) {
+ elems.push_back(retval.handle());
+ }
+ }
+ *num_nonconst_outputs = elems.size();
+
+ // Add return values for resources whose values have changed.
+ std::vector<const XlaResource*> arg_resources;
+ arg_resources.reserve(resources.size());
+ for (const auto& resource : resources) {
+ if (resource->arg_num() >= 0) {
+ arg_resources.push_back(resource.get());
+ }
+ }
+ std::sort(arg_resources.begin(), arg_resources.end(),
+ [](const XlaResource* a, const XlaResource* b) {
+ return a->arg_num() < b->arg_num();
+ });
+
+ for (const XlaResource* resource : arg_resources) {
+ const XlaCompiler::Argument& arg = args[resource->arg_num()];
+ const int core = arg_cores[resource->arg_num()];
+ DCHECK_LT(resource->arg_num(), arg_cores.size());
+ bool modified =
+ resource->value().handle() != resource->initial_value().handle();
+ // TensorArray gradients were modified if their values changed or there are
+ // any newly created gradients.
+ for (const auto& grad : resource->tensor_array_gradients()) {
+ modified = modified ||
+ grad.second->value().handle() !=
+ grad.second->initial_value().handle() ||
+ arg.tensor_array_gradients.count(grad.first) == 0;
+ }
+ if (return_updated_values_for_all_resources || modified) {
+ resource_updates->emplace_back();
+ XlaCompiler::ResourceUpdate& update = resource_updates->back();
+ update.input_index = resource->arg_num();
+ update.type = resource->type();
+ update.shape = resource->shape();
+ update.modified = modified;
+ for (const auto& grad : resource->tensor_array_gradients()) {
+ update.tensor_array_gradients_accessed.insert(grad.first);
+ }
+
+ // Request that the value be returned on a specific core.
+ xla::ScopedShardingAssignment assign_sharding(
+ builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
+ : xla::sharding_builder::AssignDevice(core));
+
+ xla::ComputationDataHandle handle;
+ TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
+
+ // Since we can't change the sharding metadata of <value> at this point,
+ // create a tuple/get-tuple-element combination so that sharding
+ // assignment will be placed on this value, which will cause the resource
+ // update to be returned from the same device that provided the resource.
+ handle = builder->GetTupleElement(builder->Tuple({handle}), 0);
+
+ elems.push_back(handle);
+ }
+ }
+
+ *num_computation_outputs = elems.size();
+
+ // Builds the XLA computation.
+ builder->Tuple(elems);
+ xla::StatusOr<xla::Computation> computation_status = builder->Build();
+ if (!computation_status.ok()) {
+ return computation_status.status();
+ }
+ *computation = computation_status.ConsumeValueOrDie();
+ return Status::OK();
+}
+
+} // namespace
+
// Builds XLA computations for each of the arguments to the computation.
// `args` are the arguments to the computation.
-Status BuildArguments(const Graph& graph,
- const std::vector<XlaCompiler::Argument>& args,
- bool use_tuple_arg, xla::ComputationBuilder* builder,
- XlaContext* context, std::vector<int>* arg_cores,
- std::vector<XlaExpression>* arg_expressions,
- std::vector<int>* input_mapping,
- std::vector<xla::Shape>* input_shapes,
- bool is_entry_computation) {
+Status XlaCompiler::BuildArguments(
+ const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
+ bool use_tuple_arg, xla::ComputationBuilder* builder, XlaContext* context,
+ std::vector<int>* arg_cores, std::vector<XlaExpression>* arg_expressions,
+ std::vector<int>* input_mapping, std::vector<xla::Shape>* input_shapes,
+ bool is_entry_computation) {
arg_expressions->resize(args.size());
*arg_cores = std::vector<int>(args.size(), -1);
@@ -374,8 +484,8 @@ Status BuildArguments(const Graph& graph,
std::vector<xla::Shape> arg_shapes(input_mapping->size());
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
// Computes the shapes of non-constant arguments.
- TF_RETURN_IF_ERROR(XlaCompiler::XLAShapeForArgument(
- args[(*input_mapping)[i]], &arg_shapes[i]));
+ TF_RETURN_IF_ERROR(
+ XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i]));
}
if (use_tuple_arg) {
@@ -472,108 +582,6 @@ Status BuildArguments(const Graph& graph,
return Status::OK();
}
-// Builds the XLA computation.
-//
-// `retvals` is the list of retvals produced by _Retval operators, in index
-// order. `variable_map` is a map from variable ID numbers to XlaOpContext
-// variable states, generated by the symbolic evaluation.
-// If `return_updated_values_for_all_resources` is true, all resources will be
-// included in `resource_updates`, regardless of whether their value changed.
-// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
-// Sets `*resource_updates` to a description of resources whose values are
-// written by the computation; the variable writes are the last
-// `resource_updates.size()` return values from the computation. Each entry in
-// `resource_updates` is a (input_index, type) pair, where `input_index` is the
-// index of a resource variable argument to the computation, and `type` is the
-// type of the final output.
-Status BuildComputation(
- const std::vector<XlaCompiler::Argument>& args,
- const std::vector<int>& arg_cores,
- const std::vector<XlaExpression>& retvals,
- const std::vector<std::unique_ptr<XlaResource>>& resources,
- bool return_updated_values_for_all_resources,
- xla::ComputationBuilder* builder, xla::Computation* computation,
- int* num_computation_outputs, int* num_nonconst_outputs,
- std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
- std::vector<xla::ComputationDataHandle> elems;
- elems.reserve(retvals.size());
- for (const XlaExpression& retval : retvals) {
- if (!retval.has_constant_value()) {
- elems.push_back(retval.handle());
- }
- }
- *num_nonconst_outputs = elems.size();
-
- // Add return values for resources whose values have changed.
- std::vector<const XlaResource*> arg_resources;
- arg_resources.reserve(resources.size());
- for (const auto& resource : resources) {
- if (resource->arg_num() >= 0) {
- arg_resources.push_back(resource.get());
- }
- }
- std::sort(arg_resources.begin(), arg_resources.end(),
- [](const XlaResource* a, const XlaResource* b) {
- return a->arg_num() < b->arg_num();
- });
-
- for (const XlaResource* resource : arg_resources) {
- const XlaCompiler::Argument& arg = args[resource->arg_num()];
- const int core = arg_cores[resource->arg_num()];
- DCHECK_LT(resource->arg_num(), arg_cores.size());
- bool modified =
- resource->value().handle() != resource->initial_value().handle();
- // TensorArray gradients were modified if their values changed or there are
- // any newly created gradients.
- for (const auto& grad : resource->tensor_array_gradients()) {
- modified = modified ||
- grad.second->value().handle() !=
- grad.second->initial_value().handle() ||
- arg.tensor_array_gradients.count(grad.first) == 0;
- }
- if (return_updated_values_for_all_resources || modified) {
- resource_updates->emplace_back();
- XlaCompiler::ResourceUpdate& update = resource_updates->back();
- update.input_index = resource->arg_num();
- update.type = resource->type();
- update.shape = resource->shape();
- update.modified = modified;
- for (const auto& grad : resource->tensor_array_gradients()) {
- update.tensor_array_gradients_accessed.insert(grad.first);
- }
-
- // Request that the value be returned on a specific core.
- xla::ScopedShardingAssignment assign_sharding(
- builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>()
- : xla::sharding_builder::AssignDevice(core));
-
- xla::ComputationDataHandle handle;
- TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
-
- // Since we can't change the sharding metadata of <value> as this point,
- // create a tuple/get-tuple-element combination so that sharding
- // assignment will be placed on this value, which will cause the resource
- // update to be returned from the same device that provided the resource.
- handle = builder->GetTupleElement(builder->Tuple({handle}), 0);
-
- elems.push_back(handle);
- }
- }
-
- *num_computation_outputs = elems.size();
-
- // Builds the XLA computation.
- builder->Tuple(elems);
- xla::StatusOr<xla::Computation> computation_status = builder->Build();
- if (!computation_status.ok()) {
- return computation_status.status();
- }
- *computation = computation_status.ConsumeValueOrDie();
- return Status::OK();
-}
-
-} // namespace
-
Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
string const& name,
std::unique_ptr<Graph> graph,
@@ -598,7 +606,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
xla::ComputationBuilder builder(client(), name);
XlaContext* context =
new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
- options.resolve_compile_time_constants);
+ options.resolve_compile_time_constants,
+ &options_.variable_representation_shape_fn);
core::ScopedUnref context_unref(context);
std::vector<XlaExpression> arg_expressions;
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index b86c82c0ab..c4449bc4be 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -29,6 +29,9 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
namespace tensorflow {
+
+class XlaContext;
+
// The XlaCompiler class is responsible for compilation of a self-contained
// subgraph of a TensorFlow computation using the XLA linear algebra runtime.
// It does a symbolic execution of the graph starting from specific input
@@ -239,6 +242,12 @@ class XlaCompiler {
// for CPU.
bool allow_cpu_custom_calls = false;
+ // If set, variables are represented to XLA using the shape given by this
+ // shape function. Variables are reshaped to this shape on write, and
+ // reshaped to their original shape on read.
+ std::function<TensorShape(const TensorShape&, DataType)>
+ variable_representation_shape_fn;
+
// If not nullptr, populate_resource_manager is called with the
// compilation device's resource manager when the compilation
// device is created, and can be used to create metadata objects
@@ -278,7 +287,7 @@ class XlaCompiler {
// Returns the shape of the XLA parameter for an argument 'arg'.
// See the class comment for more details about the argument passing
// convention.
- static Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
+ Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
// Retrieves the channel handle associated with `key`. Allocates
// a new channel handle if none exists.
@@ -299,6 +308,17 @@ class XlaCompiler {
// Returns the optimized graph object in this function body.
std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody);
+ // Builds XLA computations for each of the arguments to the computation.
+ // `args` are the arguments to the computation.
+ Status BuildArguments(const Graph& graph,
+ const std::vector<XlaCompiler::Argument>& args,
+ bool use_tuple_arg, xla::ComputationBuilder* builder,
+ XlaContext* context, std::vector<int>* arg_cores,
+ std::vector<XlaExpression>* arg_expressions,
+ std::vector<int>* input_mapping,
+ std::vector<xla::Shape>* input_shapes,
+ bool is_entry_computation);
+
// Graph compiler needs to know how to get an optimized graph from a function
// body.
friend class GraphCompiler;
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 65de4dbad7..a18eeacd41 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/data_flow_ops.h"
#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -683,5 +684,128 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
<< status.error_message();
}
+// Tests a simple graph that reads and writes a variable.
+TEST_F(XlaCompilerTest, Variables) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+ auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
+ auto write = ops::AssignAddVariableOp(scope, var, a);
+ auto read = ops::ReadVariableOp(
+ scope.WithControlDependencies(std::vector<Operation>{write}), var,
+ DT_INT32);
+ auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
+ auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(2);
+ args[0].kind = XlaCompiler::Argument::kParameter;
+ args[0].type = DT_INT32;
+ args[0].shape = TensorShape({2});
+ args[1].kind = XlaCompiler::Argument::kResource;
+ args[1].resource_kind = XlaResource::kVariable;
+ args[1].initialized = true;
+ args[1].type = DT_INT32;
+ args[1].shape = TensorShape({2});
+
+ // Compiles the graph.
+ XlaCompiler compiler(DefaultOptions());
+
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+ std::move(graph), args, &result));
+
+ // Tests that the generated computation works.
+ std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal::CreateR1<int32>({7, 42});
+ std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal::CreateR1<int32>({-3, 101});
+ std::unique_ptr<xla::GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<xla::GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<xla::GlobalData> actual =
+ client_
+ ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
+ .ConsumeValueOrDie();
+ std::unique_ptr<xla::Literal> actual_literal =
+ client_->Transfer(*actual).ConsumeValueOrDie();
+
+ std::unique_ptr<xla::Literal> expected0 =
+ xla::Literal::CreateR1<int32>({5, 144});
+ std::unique_ptr<xla::Literal> expected1 =
+ xla::Literal::CreateR1<int32>({4, 143});
+ std::unique_ptr<xla::Literal> expected_literal =
+ xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+}
+
+// Tests a simple graph that reads and writes a variable, with a
+// variable_representation_shape_fn passed to the compiler that flattens all
+// variable tensors to vectors.
+TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
+ Scope scope = Scope::NewRootScope().ExitOnError();
+ auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+ auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
+ auto write = ops::AssignAddVariableOp(scope, var, a);
+ auto read = ops::ReadVariableOp(
+ scope.WithControlDependencies(std::vector<Operation>{write}), var,
+ DT_INT32);
+ auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
+ auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+ // Builds a description of the arguments.
+ std::vector<XlaCompiler::Argument> args(2);
+ args[0].kind = XlaCompiler::Argument::kParameter;
+ args[0].type = DT_INT32;
+ args[0].shape = TensorShape({2, 2});
+ args[1].kind = XlaCompiler::Argument::kResource;
+ args[1].resource_kind = XlaResource::kVariable;
+ args[1].initialized = true;
+ args[1].type = DT_INT32;
+ args[1].shape = TensorShape({2, 2});
+
+ // Compiles the graph.
+ XlaCompiler::Options options = DefaultOptions();
+ options.variable_representation_shape_fn = [](const TensorShape& shape,
+ DataType type) {
+ return TensorShape({shape.num_elements()});
+ };
+ XlaCompiler compiler(options);
+
+ XlaCompiler::CompilationResult result;
+ TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+ std::move(graph), args, &result));
+
+ // Tests that the generated computation works.
+ std::unique_ptr<xla::Literal> param0_literal =
+ xla::Literal::CreateR2<int32>({{4, 55}, {1, -3}});
+ std::unique_ptr<xla::Literal> param1_literal =
+ xla::Literal::CreateR1<int32>({22, 11, 33, 404});
+ std::unique_ptr<xla::GlobalData> param0_data =
+ client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
+ std::unique_ptr<xla::GlobalData> param1_data =
+ client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
+
+ std::unique_ptr<xla::GlobalData> actual =
+ client_
+ ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
+ .ConsumeValueOrDie();
+ std::unique_ptr<xla::Literal> actual_literal =
+ client_->Transfer(*actual).ConsumeValueOrDie();
+
+ std::unique_ptr<xla::Literal> expected0 =
+ xla::Literal::CreateR2<int32>({{27, 67}, {35, 402}});
+ std::unique_ptr<xla::Literal> expected1 =
+ xla::Literal::CreateR1<int32>({26, 66, 34, 401});
+ std::unique_ptr<xla::Literal> expected_literal =
+ xla::Literal::MakeTuple({expected0.get(), expected1.get()});
+ xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 73878955e3..8423921086 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -62,13 +62,16 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
args_ = std::move(args);
}
-XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
- bool allow_cpu_custom_calls,
- bool resolve_compile_time_constants)
+XlaContext::XlaContext(
+ XlaCompiler* compiler, xla::ComputationBuilder* builder,
+ bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
+ const std::function<TensorShape(const TensorShape&, DataType)>*
+ variable_representation_shape_fn)
: compiler_(compiler),
builder_(builder),
allow_cpu_custom_calls_(allow_cpu_custom_calls),
- resolve_compile_time_constants_(resolve_compile_time_constants) {}
+ resolve_compile_time_constants_(resolve_compile_time_constants),
+ variable_representation_shape_fn_(variable_representation_shape_fn) {}
string XlaContext::DebugString() { return "TLA JIT context"; }
@@ -115,6 +118,11 @@ Status XlaContext::CreateResource(
return Status::OK();
}
+TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape,
+ DataType type) const {
+ return (*variable_representation_shape_fn_)(shape, type);
+}
+
const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) {
return LookupOrCreate(type, &max_func_, [this, type] {
const string type_string = DataTypeString(type);
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index fac0352ae8..00fbaba37c 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -44,7 +44,9 @@ class XlaContext : public ResourceBase {
// Creates a new XlaContext.
XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder,
- bool allow_cpu_custom_calls, bool resolve_compile_time_constants);
+ bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
+ const std::function<TensorShape(const TensorShape&, DataType)>*
+ variable_representation_shape_fn);
// Virtual method defined by ResourceBase.
string DebugString() override;
@@ -86,6 +88,11 @@ class XlaContext : public ResourceBase {
return resources_;
}
+ // Returns the XLA shape to be used to represent a variable of TF `shape`
+ // and `type`.
+ TensorShape VariableRepresentationShape(const TensorShape& shape,
+ DataType type) const;
+
// Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
// separate specialization of the computation for each DataType.
@@ -133,6 +140,11 @@ class XlaContext : public ResourceBase {
// Holds ownership of resources. The resources are not ordered.
std::vector<std::unique_ptr<XlaResource>> resources_;
+ // A function that describes how variable shapes should be represented
+ // in XLA. Variable values will be reshaped to this shape. Must be non-null.
+ const std::function<TensorShape(const TensorShape&, DataType)>*
+ variable_representation_shape_fn_;
+
// Cache of prebuilt computations indexed by their type.
using ComputationMap = std::map<DataType, xla::Computation>;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index ee29158646..c4bb90d587 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -302,10 +302,19 @@ Status XlaOpKernelContext::ReadVariableInput(
"Type mismatch for read of variable ", variable->name(), ". Expected ",
DataTypeString(type), "; got ", DataTypeString(variable->type()));
}
- *value = variable->value();
if (shape) {
*shape = variable->shape();
}
+
+ XlaContext& xla_context = XlaContext::Get(context_);
+ TensorShape representation_shape = xla_context.VariableRepresentationShape(
+ variable->shape(), variable->type());
+ if (representation_shape == variable->shape()) {
+ *value = variable->value();
+ } else {
+ *value =
+ builder()->Reshape(variable->value(), variable->shape().dim_sizes());
+ }
return Status::OK();
}
@@ -400,8 +409,8 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
return Status::OK();
}
-Status XlaOpKernelContext::AssignVariable(
- int input_index, DataType type, const xla::ComputationDataHandle& handle) {
+Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
+ xla::ComputationDataHandle handle) {
TF_RET_CHECK(handle.handle() != 0);
const XlaExpression* expression =
@@ -419,6 +428,13 @@ Status XlaOpKernelContext::AssignVariable(
XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape));
TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape));
+
+ XlaContext& xla_context = XlaContext::Get(context_);
+ TensorShape representation_shape =
+ xla_context.VariableRepresentationShape(shape, type);
+ if (shape != representation_shape) {
+ handle = builder()->Reshape(handle, representation_shape.dim_sizes());
+ }
return variable->SetValue(handle);
}
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index e1fd0f55c6..4e4b97e0ce 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -175,7 +175,7 @@ class XlaOpKernelContext {
// variable has been initialized with a different type or with a
// different shape.
Status AssignVariable(int input_index, DataType type,
- const xla::ComputationDataHandle& handle);
+ xla::ComputationDataHandle handle);
// Helper routines for the OP_REQUIRES macros
void CtxFailure(const Status& s);