diff options
author | 2018-02-14 14:49:23 -0800 | |
---|---|---|
committer | 2018-02-14 14:53:40 -0800 | |
commit | aa92995493939b674cd54b13b5850cc185a6e7ae (patch) | |
tree | 5ce3c336a1ca4d1241af5060de7869a64ba48c30 /tensorflow/compiler | |
parent | 90159c53b83527bd088452769ea4e1b98667860c (diff) |
[TF:XLA] Add a hook to allow reshaping of TensorFlow variables when storing them in their XLA representation.
PiperOrigin-RevId: 185748660
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/tf2xla/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/literal_util.cc | 24 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/literal_util.h | 9 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.cc | 243 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler.h | 22 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_compiler_test.cc | 124 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_context.cc | 16 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_context.h | 14 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.cc | 22 | ||||
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_kernel.h | 2 |
10 files changed, 341 insertions, 136 deletions
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 3c7dfef03d..fb82c2601c 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -312,6 +312,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index fcbd157c61..2c3cd658e0 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -40,20 +40,20 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { return Status::OK(); } -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor) { +Status CopyLiteralToHostTensor(const xla::Literal& literal, + Tensor* host_tensor) { + TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && + xla::ShapeUtil::ElementsIn(literal.shape()) == + host_tensor->NumElements()); xla::PrimitiveType primitive_type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(target_type, &primitive_type)); + TF_RETURN_IF_ERROR( + DataTypeToPrimitiveType(host_tensor->dtype(), &primitive_type)); if (literal.shape().element_type() != primitive_type) { return errors::InvalidArgument( "Cannot convert literal of type ", xla::PrimitiveType_Name(literal.shape().element_type()), - " to tensor of type ", DataTypeString(target_type)); + " to tensor of type ", DataTypeString(host_tensor->dtype())); } - - TensorShape shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); - *host_tensor = Tensor(target_type, shape); size_t total_bytes = host_tensor->TotalBytes(); if (total_bytes > 0) { const void* src_ptr = literal.untyped_data(); @@ -63,4 +63,12 @@ Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, return Status::OK(); } +Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, + Tensor* host_tensor) { + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); + *host_tensor = Tensor(target_type, shape); + return CopyLiteralToHostTensor(literal, host_tensor); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index fe08e83c23..f283b02368 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -29,7 +29,8 @@ namespace tensorflow { // unsupported type. Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); -// Copies 'literal' to 'host_tensor', which is allocated of type <target_type>. +// Copies 'literal' to freshly allocated 'host_tensor', which is allocated of +// type <target_type>. // Fails if the literal's primitive type != // DataTypeToPrimitiveType(target_type). Note that <target_type> is not // derivable from the type of <literal>, because multiple tensorflow types map @@ -38,6 +39,12 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, Tensor* host_tensor); +// Copies the contents of 'literal' to a previously allocated tensor +// 'host_tensor'. The tensor and the literal must have the same number of +// elements and the same type. +Status CopyLiteralToHostTensor(const xla::Literal& literal, + Tensor* host_tensor); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 59e8830442..15bba46ac6 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -109,6 +109,12 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); flib_runtime_ = pflr_->GetFLR(device_->name()); + + // The default variable representation shape is the identity function. + if (!options_.variable_representation_shape_fn) { + options_.variable_representation_shape_fn = + [](const TensorShape& shape, DataType type) { return shape; }; + } } XlaCompiler::~XlaCompiler() = default; @@ -223,8 +229,8 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, } // Computes the XLA shape for argument 'arg'. -/*static*/ Status XlaCompiler::XLAShapeForArgument( - const XlaCompiler::Argument& arg, xla::Shape* xla_shape) { +Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, + xla::Shape* xla_shape) { switch (arg.kind) { case XlaCompiler::Argument::kConstant: return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), @@ -235,8 +241,12 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, TF_RET_CHECK(arg.initialized); switch (arg.resource_kind) { - case XlaResource::kVariable: - return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + case XlaResource::kVariable: { + TensorShape representation_shape = + options_.variable_representation_shape_fn(arg.shape, arg.type); + return TensorShapeToXLAShape(arg.type, representation_shape, + xla_shape); + } case XlaResource::kTensorArray: { if (arg.tensor_array_size < 0) { return errors::InvalidArgument( @@ -310,16 +320,116 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph, return Status::OK(); } +// Builds the XLA computation. +// +// `retvals` is the list of retvals produced by _Retval operators, in index +// order. `variable_map` is a map from variable ID numbers to XlaOpContext +// variable states, generated by the symbolic evaluation. +// If `return_updated_values_for_all_resources` is true, all resources will be +// included in `resource_updates`, regardless of whether their value changed. +// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. +// Sets `*resource_updates` to a description of resources whose values are +// written by the computation; the variable writes are the last +// `resource_updates.size()` return values from the computation. Each entry in +// `resource_updates` is a (input_index, type) pair, where `input_index` is the +// index of a resource variable argument to the computation, and `type` is the +// type of the final output. +Status BuildComputation( + const std::vector<XlaCompiler::Argument>& args, + const std::vector<int>& arg_cores, + const std::vector<XlaExpression>& retvals, + const std::vector<std::unique_ptr<XlaResource>>& resources, + bool return_updated_values_for_all_resources, + xla::ComputationBuilder* builder, xla::Computation* computation, + int* num_computation_outputs, int* num_nonconst_outputs, + std::vector<XlaCompiler::ResourceUpdate>* resource_updates) { + std::vector<xla::ComputationDataHandle> elems; + elems.reserve(retvals.size()); + for (const XlaExpression& retval : retvals) { + if (!retval.has_constant_value()) { + elems.push_back(retval.handle()); + } + } + *num_nonconst_outputs = elems.size(); + + // Add return values for resources whose values have changed. + std::vector<const XlaResource*> arg_resources; + arg_resources.reserve(resources.size()); + for (const auto& resource : resources) { + if (resource->arg_num() >= 0) { + arg_resources.push_back(resource.get()); + } + } + std::sort(arg_resources.begin(), arg_resources.end(), + [](const XlaResource* a, const XlaResource* b) { + return a->arg_num() < b->arg_num(); + }); + + for (const XlaResource* resource : arg_resources) { + const XlaCompiler::Argument& arg = args[resource->arg_num()]; + const int core = arg_cores[resource->arg_num()]; + DCHECK_LT(resource->arg_num(), arg_cores.size()); + bool modified = + resource->value().handle() != resource->initial_value().handle(); + // TensorArray gradients were modified if their values changed or there are + // any newly created gradients. + for (const auto& grad : resource->tensor_array_gradients()) { + modified = modified || + grad.second->value().handle() != + grad.second->initial_value().handle() || + arg.tensor_array_gradients.count(grad.first) == 0; + } + if (return_updated_values_for_all_resources || modified) { + resource_updates->emplace_back(); + XlaCompiler::ResourceUpdate& update = resource_updates->back(); + update.input_index = resource->arg_num(); + update.type = resource->type(); + update.shape = resource->shape(); + update.modified = modified; + for (const auto& grad : resource->tensor_array_gradients()) { + update.tensor_array_gradients_accessed.insert(grad.first); + } + + // Request that the value be returned on a specific core. + xla::ScopedShardingAssignment assign_sharding( + builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>() + : xla::sharding_builder::AssignDevice(core)); + + xla::ComputationDataHandle handle; + TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); + + // Since we can't change the sharding metadata of <value> as this point, + // create a tuple/get-tuple-element combination so that sharding + // assignment will be placed on this value, which will cause the resource + // update to be returned from the same device that provided the resource. + handle = builder->GetTupleElement(builder->Tuple({handle}), 0); + + elems.push_back(handle); + } + } + + *num_computation_outputs = elems.size(); + + // Builds the XLA computation. + builder->Tuple(elems); + xla::StatusOr<xla::Computation> computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status(); + } + *computation = computation_status.ConsumeValueOrDie(); + return Status::OK(); +} + +} // namespace + // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. -Status BuildArguments(const Graph& graph, - const std::vector<XlaCompiler::Argument>& args, - bool use_tuple_arg, xla::ComputationBuilder* builder, - XlaContext* context, std::vector<int>* arg_cores, - std::vector<XlaExpression>* arg_expressions, - std::vector<int>* input_mapping, - std::vector<xla::Shape>* input_shapes, - bool is_entry_computation) { +Status XlaCompiler::BuildArguments( + const Graph& graph, const std::vector<XlaCompiler::Argument>& args, + bool use_tuple_arg, xla::ComputationBuilder* builder, XlaContext* context, + std::vector<int>* arg_cores, std::vector<XlaExpression>* arg_expressions, + std::vector<int>* input_mapping, std::vector<xla::Shape>* input_shapes, + bool is_entry_computation) { arg_expressions->resize(args.size()); *arg_cores = std::vector<int>(args.size(), -1); @@ -374,8 +484,8 @@ Status BuildArguments(const Graph& graph, std::vector<xla::Shape> arg_shapes(input_mapping->size()); for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) { // Computes the shapes of non-constant arguments. - TF_RETURN_IF_ERROR(XlaCompiler::XLAShapeForArgument( - args[(*input_mapping)[i]], &arg_shapes[i])); + TF_RETURN_IF_ERROR( + XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i])); } if (use_tuple_arg) { @@ -472,108 +582,6 @@ Status BuildArguments(const Graph& graph, return Status::OK(); } -// Builds the XLA computation. -// -// `retvals` is the list of retvals produced by _Retval operators, in index -// order. `variable_map` is a map from variable ID numbers to XlaOpContext -// variable states, generated by the symbolic evaluation. -// If `return_updated_values_for_all_resources` is true, all resources will be -// included in `resource_updates`, regardless of whether their value changed. -// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. -// Sets `*resource_updates` to a description of resources whose values are -// written by the computation; the variable writes are the last -// `resource_updates.size()` return values from the computation. Each entry in -// `resource_updates` is a (input_index, type) pair, where `input_index` is the -// index of a resource variable argument to the computation, and `type` is the -// type of the final output. -Status BuildComputation( - const std::vector<XlaCompiler::Argument>& args, - const std::vector<int>& arg_cores, - const std::vector<XlaExpression>& retvals, - const std::vector<std::unique_ptr<XlaResource>>& resources, - bool return_updated_values_for_all_resources, - xla::ComputationBuilder* builder, xla::Computation* computation, - int* num_computation_outputs, int* num_nonconst_outputs, - std::vector<XlaCompiler::ResourceUpdate>* resource_updates) { - std::vector<xla::ComputationDataHandle> elems; - elems.reserve(retvals.size()); - for (const XlaExpression& retval : retvals) { - if (!retval.has_constant_value()) { - elems.push_back(retval.handle()); - } - } - *num_nonconst_outputs = elems.size(); - - // Add return values for resources whose values have changed. - std::vector<const XlaResource*> arg_resources; - arg_resources.reserve(resources.size()); - for (const auto& resource : resources) { - if (resource->arg_num() >= 0) { - arg_resources.push_back(resource.get()); - } - } - std::sort(arg_resources.begin(), arg_resources.end(), - [](const XlaResource* a, const XlaResource* b) { - return a->arg_num() < b->arg_num(); - }); - - for (const XlaResource* resource : arg_resources) { - const XlaCompiler::Argument& arg = args[resource->arg_num()]; - const int core = arg_cores[resource->arg_num()]; - DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = - resource->value().handle() != resource->initial_value().handle(); - // TensorArray gradients were modified if their values changed or there are - // any newly created gradients. - for (const auto& grad : resource->tensor_array_gradients()) { - modified = modified || - grad.second->value().handle() != - grad.second->initial_value().handle() || - arg.tensor_array_gradients.count(grad.first) == 0; - } - if (return_updated_values_for_all_resources || modified) { - resource_updates->emplace_back(); - XlaCompiler::ResourceUpdate& update = resource_updates->back(); - update.input_index = resource->arg_num(); - update.type = resource->type(); - update.shape = resource->shape(); - update.modified = modified; - for (const auto& grad : resource->tensor_array_gradients()) { - update.tensor_array_gradients_accessed.insert(grad.first); - } - - // Request that the value be returned on a specific core. - xla::ScopedShardingAssignment assign_sharding( - builder, core == -1 ? tensorflow::gtl::optional<xla::OpSharding>() - : xla::sharding_builder::AssignDevice(core)); - - xla::ComputationDataHandle handle; - TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); - - // Since we can't change the sharding metadata of <value> as this point, - // create a tuple/get-tuple-element combination so that sharding - // assignment will be placed on this value, which will cause the resource - // update to be returned from the same device that provided the resource. - handle = builder->GetTupleElement(builder->Tuple({handle}), 0); - - elems.push_back(handle); - } - } - - *num_computation_outputs = elems.size(); - - // Builds the XLA computation. - builder->Tuple(elems); - xla::StatusOr<xla::Computation> computation_status = builder->Build(); - if (!computation_status.ok()) { - return computation_status.status(); - } - *computation = computation_status.ConsumeValueOrDie(); - return Status::OK(); -} - -} // namespace - Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr<Graph> graph, @@ -598,7 +606,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, xla::ComputationBuilder builder(client(), name); XlaContext* context = new XlaContext(this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants); + options.resolve_compile_time_constants, + &options_.variable_representation_shape_fn); core::ScopedUnref context_unref(context); std::vector<XlaExpression> arg_expressions; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index b86c82c0ab..c4449bc4be 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -29,6 +29,9 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { + +class XlaContext; + // The XlaCompiler class is responsible for compilation of a self-contained // subgraph of a TensorFlow computation using the XLA linear algebra runtime. // It does a symbolic execution of the graph starting from specific input @@ -239,6 +242,12 @@ class XlaCompiler { // for CPU. bool allow_cpu_custom_calls = false; + // If set, the XLA representation of variables represented to XLA as the + // shape given by this shape function. Variables are reshaped to this shape + // on write, and reshaped to their original shape on read. + std::function<TensorShape(const TensorShape&, DataType)> + variable_representation_shape_fn; + // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation // device is created, and can be used to create metadata objects @@ -278,7 +287,7 @@ class XlaCompiler { // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. - static Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); + Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. @@ -299,6 +308,17 @@ class XlaCompiler { // Returns the optimized graph object in this function body. std::unique_ptr<Graph> GetGraph(const FunctionBody* fbody); + // Builds XLA computations for each of the arguments to the computation. + // `args` are the arguments to the computation. + Status BuildArguments(const Graph& graph, + const std::vector<XlaCompiler::Argument>& args, + bool use_tuple_arg, xla::ComputationBuilder* builder, + XlaContext* context, std::vector<int>* arg_cores, + std::vector<XlaExpression>* arg_expressions, + std::vector<int>* input_mapping, + std::vector<xla::Shape>* input_shapes, + bool is_entry_computation); + // Graph compiler needs to know how to get an optimized graph from a function // body. friend class GraphCompiler; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 65de4dbad7..a18eeacd41 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -683,5 +684,128 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { << status.error_message(); } +// Tests a simple graph that reads and writes a variable. +TEST_F(XlaCompilerTest, Variables) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + auto write = ops::AssignAddVariableOp(scope, var, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector<Operation>{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector<XlaCompiler::Argument> args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + // Tests that the generated computation works. + std::unique_ptr<xla::Literal> param0_literal = + xla::Literal::CreateR1<int32>({7, 42}); + std::unique_ptr<xla::Literal> param1_literal = + xla::Literal::CreateR1<int32>({-3, 101}); + std::unique_ptr<xla::GlobalData> param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr<xla::GlobalData> param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr<xla::GlobalData> actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr<xla::Literal> actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr<xla::Literal> expected0 = + xla::Literal::CreateR1<int32>({5, 144}); + std::unique_ptr<xla::Literal> expected1 = + xla::Literal::CreateR1<int32>({4, 143}); + std::unique_ptr<xla::Literal> expected_literal = + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); +} + +// Tests a simple graph that reads and writes a variable, with a +// variable_representation_shape_fn passed to the compiler that flattens all +// variable tensors to vectors. +TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + auto write = ops::AssignAddVariableOp(scope, var, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector<Operation>{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector<XlaCompiler::Argument> args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 2}); + + // Compiles the graph. + XlaCompiler::Options options = DefaultOptions(); + options.variable_representation_shape_fn = [](const TensorShape& shape, + DataType type) { + return TensorShape({shape.num_elements()}); + }; + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + // Tests that the generated computation works. + std::unique_ptr<xla::Literal> param0_literal = + xla::Literal::CreateR2<int32>({{4, 55}, {1, -3}}); + std::unique_ptr<xla::Literal> param1_literal = + xla::Literal::CreateR1<int32>({22, 11, 33, 404}); + std::unique_ptr<xla::GlobalData> param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr<xla::GlobalData> param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr<xla::GlobalData> actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr<xla::Literal> actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr<xla::Literal> expected0 = + xla::Literal::CreateR2<int32>({{27, 67}, {35, 402}}); + std::unique_ptr<xla::Literal> expected1 = + xla::Literal::CreateR1<int32>({26, 66, 34, 401}); + std::unique_ptr<xla::Literal> expected_literal = + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 73878955e3..8423921086 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -62,13 +62,16 @@ void XlaContext::set_args(std::vector<XlaExpression> args) { args_ = std::move(args); } -XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, - bool allow_cpu_custom_calls, - bool resolve_compile_time_constants) +XlaContext::XlaContext( + XlaCompiler* compiler, xla::ComputationBuilder* builder, + bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + const std::function<TensorShape(const TensorShape&, DataType)>* + variable_representation_shape_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), - resolve_compile_time_constants_(resolve_compile_time_constants) {} + resolve_compile_time_constants_(resolve_compile_time_constants), + variable_representation_shape_fn_(variable_representation_shape_fn) {} string XlaContext::DebugString() { return "TLA JIT context"; } @@ -115,6 +118,11 @@ Status XlaContext::CreateResource( return Status::OK(); } +TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape, + DataType type) const { + return (*variable_representation_shape_fn_)(shape, type); +} + const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) { return LookupOrCreate(type, &max_func_, [this, type] { const string type_string = DataTypeString(type); diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index fac0352ae8..00fbaba37c 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -44,7 +44,9 @@ class XlaContext : public ResourceBase { // Creates a new XlaContext. XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, - bool allow_cpu_custom_calls, bool resolve_compile_time_constants); + bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + const std::function<TensorShape(const TensorShape&, DataType)>* + variable_representation_shape_fn); // Virtual method defined by ResourceBase. string DebugString() override; @@ -86,6 +88,11 @@ class XlaContext : public ResourceBase { return resources_; } + // Returns the XLA shape to be used to represent a variable of TF `shape` + // and `type`. + TensorShape VariableRepresentationShape(const TensorShape& shape, + DataType type) const; + // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. @@ -133,6 +140,11 @@ class XlaContext : public ResourceBase { // Holds ownership of resources. The resources are not ordered. std::vector<std::unique_ptr<XlaResource>> resources_; + // A function that describes how variable shapes should be represented + // in XLA. Variable values will be reshaped to this shape. Must be non-null. + const std::function<TensorShape(const TensorShape&, DataType)>* + variable_representation_shape_fn_; + // Cache of prebuilt computations indexed by their type. using ComputationMap = std::map<DataType, xla::Computation>; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index ee29158646..c4bb90d587 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -302,10 +302,19 @@ Status XlaOpKernelContext::ReadVariableInput( "Type mismatch for read of variable ", variable->name(), ". Expected ", DataTypeString(type), "; got ", DataTypeString(variable->type())); } - *value = variable->value(); if (shape) { *shape = variable->shape(); } + + XlaContext& xla_context = XlaContext::Get(context_); + TensorShape representation_shape = xla_context.VariableRepresentationShape( + variable->shape(), variable->type()); + if (representation_shape == variable->shape()) { + *value = variable->value(); + } else { + *value = + builder()->Reshape(variable->value(), variable->shape().dim_sizes()); + } return Status::OK(); } @@ -400,8 +409,8 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { return Status::OK(); } -Status XlaOpKernelContext::AssignVariable( - int input_index, DataType type, const xla::ComputationDataHandle& handle) { +Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, + xla::ComputationDataHandle handle) { TF_RET_CHECK(handle.handle() != 0); const XlaExpression* expression = @@ -419,6 +428,13 @@ Status XlaOpKernelContext::AssignVariable( XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); + + XlaContext& xla_context = XlaContext::Get(context_); + TensorShape representation_shape = + xla_context.VariableRepresentationShape(shape, type); + if (shape != representation_shape) { + handle = builder()->Reshape(handle, representation_shape.dim_sizes()); + } return variable->SetValue(handle); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index e1fd0f55c6..4e4b97e0ce 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -175,7 +175,7 @@ class XlaOpKernelContext { // variable has been initialized with a different type or with a // different shape. Status AssignVariable(int input_index, DataType type, - const xla::ComputationDataHandle& handle); + xla::ComputationDataHandle handle); // Helper routines for the OP_REQUIRES macros void CtxFailure(const Status& s); |