diff options
author | Peter Hawkins <phawkins@google.com> | 2018-05-15 10:33:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-15 10:36:11 -0700 |
commit | a7a3bb3df12c632b81bf1b23f8405f92a0c903c3 (patch) | |
tree | 5291bb0bed562ce40c5a13014c6c955c71e4e10a | |
parent | 2c83cddab0cd0de78f863e47d81b4427d6519eb7 (diff) |
Automated g4 rollback of changelist 196683444
PiperOrigin-RevId: 196691101
20 files changed, 247 insertions, 475 deletions
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index fee46280e9..868d752927 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -551,16 +551,14 @@ TEST(TFCompileTest, HloProfiling) { auto header = HasSubstr("Execution profile for"); auto total_cycles_profile_line = HasSubstr("[total]"); auto dot_profile_line = HasSubstr( - "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " "%arg1.0.1)"); auto add_profile_line = HasSubstr( - "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " + "%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} " "%arg1.0.1)"); auto tuple_profile_line = HasSubstr( "%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} " - "%dot.0.4, f32[2,2]{1,0} %add.0.6)"); - auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)"); - auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)"); + "%dot.0.2, f32[2,2]{1,0} %add.0.5)"); EXPECT_THAT(hlo_profile_lines, IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index 9d856346ec..86a9fd3b8e 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -112,7 +112,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { // this is more obviously correct.) 
core::ScopedUnref cache_ref(cache); - const XlaDevice::Metadata* metadata = nullptr; + const XlaDevice::Metadata* metadata; Status s = XlaDevice::GetMetadata(ctx, &metadata); bool allocate_xla_tensors = s.ok(); @@ -153,9 +153,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId); options.device_allocator = xla_allocator; - if (metadata) { - options.shape_representation_fn = metadata->shape_representation_fn(); - } + // TODO(b/77671268): We don't set variable_representation_shape_fn here. This + // is restricted to Variables, but we need something like this to apply to + // normal Tensors too. const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; @@ -164,11 +164,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { for (int i : constants_) { constant_args.insert({i, ctx->input(i)}); } - XlaCompiler::CompileOptions compile_options; - compile_options.is_entry_computation = true; - OP_REQUIRES_OK( - ctx, cache->Compile(options, function_, constant_args, variables, ctx, - &kernel, &executable, &compile_options)); + OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args, + variables, ctx, &kernel, &executable, + /*compile_options=*/nullptr)); VLOG(1) << "Executing XLA Computation..."; diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index ea9e036604..bc07dbd7bd 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -50,11 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, (void)registrations; std::unique_ptr<XlaDevice> device; - TF_RETURN_IF_ERROR( - XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, - name_prefix, registration, - /*transfer_as_literal=*/false, - /*shape_representation_fn=*/{}, &device)); + 
TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, + DEVICE_CPU_XLA_JIT, options, name_prefix, + registration, + /*transfer_as_literal=*/false, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 9ee5b04e80..70263b1ff9 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -110,9 +110,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( const string& jit_device_name, const SessionOptions& options, const string& name_prefix, const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, - const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, - std::unique_ptr<XlaDevice>* device) { + bool transfer_as_literal, std::unique_ptr<XlaDevice>* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; @@ -131,19 +129,17 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(), strings::StrCat("device: ", device_name, " device")); - device->reset(new XlaDevice( - options, attrs, device_ordinal, DeviceType(jit_device_name), - platform.ValueOrDie(), transfer_as_literal, shape_representation_fn)); + device->reset(new XlaDevice(options, attrs, device_ordinal, + DeviceType(jit_device_name), + platform.ValueOrDie(), transfer_as_literal)); return Status::OK(); } -XlaDevice::Metadata::Metadata( - int device_ordinal, se::Platform* platform, const DeviceType& device_type, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) +XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform, + const DeviceType& device_type) : device_ordinal_(device_ordinal), device_type_(device_type), - platform_(platform), - shape_representation_fn_(std::move(shape_representation_fn)) {} + platform_(platform) {} int 
XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; } @@ -174,20 +170,17 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { return Status::OK(); } -XlaDevice::XlaDevice( - const SessionOptions& options, const DeviceAttributes& attrs, - int device_ordinal, const DeviceType& jit_device_name, - se::Platform* platform, bool transfer_as_literal, - const XlaCompiler::ShapeRepresentationFn& shape_representation_fn) +XlaDevice::XlaDevice(const SessionOptions& options, + const DeviceAttributes& attrs, int device_ordinal, + const DeviceType& jit_device_name, se::Platform* platform, + bool transfer_as_literal) : LocalDevice(options, attrs), - xla_metadata_(device_ordinal, platform, jit_device_name, - shape_representation_fn), + xla_metadata_(device_ordinal, platform, jit_device_name), device_ordinal_(device_ordinal), jit_device_name_(jit_device_name), xla_allocator_(nullptr), platform_(platform), - transfer_as_literal_(transfer_as_literal), - shape_representation_fn_(shape_representation_fn) { + transfer_as_literal_(transfer_as_literal) { VLOG(1) << "Created XLA device " << jit_device_name; } @@ -239,8 +232,8 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() { // gpu_device_info_->default_context. gpu_device_info_ = absl::make_unique<GpuDeviceInfo>(); gpu_device_info_->stream = stream; - gpu_device_info_->default_context = new XlaDeviceContext( - stream, client(), transfer_as_literal_, shape_representation_fn_); + gpu_device_info_->default_context = + new XlaDeviceContext(stream, client(), transfer_as_literal_); set_tensorflow_gpu_device_info(gpu_device_info_.get()); } @@ -254,8 +247,7 @@ Status XlaDevice::FillContextMap(const Graph* graph, TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); // Call GetAllocator for the side-effect of ensuring the allocator is created. 
GetAllocator({}); - auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_, - shape_representation_fn_); + auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_); for (Node* n : graph->nodes()) { VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name(); ctx->Ref(); @@ -302,8 +294,7 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto, Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape()); Notification n; TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream()); - XlaTransferManager manager(stream, client(), transfer_as_literal_, - shape_representation_fn_); + XlaTransferManager manager(stream, client(), transfer_as_literal_); manager.CopyCPUTensorToDevice(&parsed, this, &copy, [&n, &status](const Status& s) { status = s; diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index d5d345d43b..3ae87308cc 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -17,7 +17,8 @@ limitations under the License. // runtime. // // Operators assigned to an XlaDevice are compiled into XLA computations. -// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers. +// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state +// is managed by XLA. // // XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU), // under different names (e.g., XLA_CPU or XLA_GPU). @@ -26,7 +27,6 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_ #include "tensorflow/compiler/jit/xla_tensor.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -50,8 +50,7 @@ class XlaDevice : public LocalDevice { class Metadata { public: Metadata(int device_ordinal, se::Platform* platform, - const DeviceType& device_type, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + const DeviceType& device_type); // The index of the device on this host. int device_ordinal() const; @@ -59,15 +58,11 @@ class XlaDevice : public LocalDevice { se::Platform* platform() const; xla::LocalClient* client() const; const DeviceType& jit_device_type() const; - const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const { - return shape_representation_fn_; - } private: const int device_ordinal_; const DeviceType device_type_; se::Platform* platform_; // Not owned. - XlaCompiler::ShapeRepresentationFn shape_representation_fn_; TF_DISALLOW_COPY_AND_ASSIGN(Metadata); }; @@ -81,19 +76,16 @@ class XlaDevice : public LocalDevice { // 'transfer_as_literal' is true if device<->host transfers must be done using // XLA's TransferLiteral{To,From}Device interface. If false, we can use // ThenMemcpy instead. 
- static Status Create( - const string& platform_name, const string& device_name, - int device_ordinal, const string& jit_device_name, - const SessionOptions& options, const string& name_prefix, - const XlaOpRegistry::DeviceRegistration& registration, - bool transfer_as_literal, - const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, - std::unique_ptr<XlaDevice>* device); + static Status Create(const string& platform_name, const string& device_name, + int device_ordinal, const string& jit_device_name, + const SessionOptions& options, const string& name_prefix, + const XlaOpRegistry::DeviceRegistration& registration, + bool transfer_as_literal, + std::unique_ptr<XlaDevice>* device); XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, int device_ordinal, const DeviceType& jit_device_name, - se::Platform* platform, bool transfer_as_literal, - const XlaCompiler::ShapeRepresentationFn& shape_representation_fn); + se::Platform* platform, bool transfer_as_literal); ~XlaDevice() override; Allocator* GetAllocator(AllocatorAttributes attr) override; @@ -124,8 +116,8 @@ class XlaDevice : public LocalDevice { // The name of the device that is used to compile Ops for this XlaDevice. DeviceType jit_device_name_; // Memory allocator associated with this device. - Allocator* xla_allocator_; // Not owned. - se::Platform* platform_; // Not owned. + Allocator* xla_allocator_; // Not owned. + se::Platform* platform_; // Not owned. // Stream associated with this device. Operations enqueued on this // stream are executed on the device. Operations include data // copying back and forth between CPU and the device, and @@ -134,7 +126,6 @@ class XlaDevice : public LocalDevice { // Must we use XLA's transfer manager for correct host<->device transfers? if // false, we can use ThenMemcpy() instead. 
bool transfer_as_literal_; - XlaCompiler::ShapeRepresentationFn shape_representation_fn_; // If set, holds default device context (that we must Unref) // and its stream. diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index ff30b62bad..bf8c1886a0 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -47,14 +47,13 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) { void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); } -XlaTransferManager::XlaTransferManager( - se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) +XlaTransferManager::XlaTransferManager(se::Stream* stream, + xla::LocalClient* client, + bool transfer_as_literal) : stream_(stream), client_(client), transfer_manager_(client->backend().transfer_manager()), - transfer_as_literal_(transfer_as_literal), - shape_representation_fn_(std::move(shape_representation_fn)) {} + transfer_as_literal_(transfer_as_literal) {} Status XlaTransferManager::TransferLiteralToDevice( const Tensor& host_tensor, Tensor* device_tensor) const { @@ -77,15 +76,7 @@ Status XlaTransferManager::TransferLiteralFromDevice( transfer_manager_->TransferLiteralFromDevice( stream_->parent(), shaped_buffer)); VLOG(1) << "Transfer from device as literal: " << literal->ToString(); - Tensor tensor; - TF_RETURN_IF_ERROR( - LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor)); - // Reshape the tensor back to its declared shape. 
- if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) { - return errors::Internal( - "Tensor::CopyFrom failed when copying from XLA device to CPU"); - } - return Status::OK(); + return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor); } void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, @@ -105,17 +96,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor); CHECK(xla_tensor); - - TensorShape shape; - if (shape_representation_fn_) { - shape = shape_representation_fn_(device_tensor->shape(), - device_tensor->dtype()); - } else { - shape = device_tensor->shape(); - } if (!xla_tensor->has_shaped_buffer()) { Status s = xla_tensor->AllocateShapedBuffer( - device_tensor->dtype(), shape, client_, + device_tensor->dtype(), device_tensor->shape(), client_, stream_->parent()->device_ordinal()); if (!s.ok()) { done(s); @@ -123,18 +106,12 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, } } + se::DeviceMemoryBase dev_dst_ptr = + XlaTensor::DeviceMemoryFromTensor(*device_tensor); Status status; if (transfer_as_literal_) { - Tensor reshaped_cpu_tensor; - if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) { - done(errors::Internal( - "Tensor::CopyFrom failed when copying from CPU to XLA device")); - return; - } - status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor); + status = TransferLiteralToDevice(*cpu_tensor, device_tensor); } else { - se::DeviceMemoryBase dev_dst_ptr = - XlaTensor::DeviceMemoryFromTensor(*device_tensor); stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. 
Status block_status = stream_->BlockHostUntilDone(); @@ -194,11 +171,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(Status::OK()); } -XlaDeviceContext::XlaDeviceContext( - se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn) - : manager_(stream, client, transfer_as_literal, - std::move(shape_representation_fn)) {} +XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, + bool transfer_as_literal) + : manager_(stream, client, transfer_as_literal) {} void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 9af9655868..d7f5f1d208 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -19,7 +19,6 @@ limitations under the License. #include <memory> #include "tensorflow/compiler/jit/xla_tensor.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/framework/allocator.h" @@ -46,9 +45,8 @@ class XlaDeviceAllocator : public Allocator { // Helper class for managing data transfers between host and XLA devices. class XlaTransferManager { public: - explicit XlaTransferManager( - se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + explicit XlaTransferManager(se::Stream* stream, xla::LocalClient* client, + bool transfer_as_literal); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, StatusCallback done) const; @@ -71,8 +69,7 @@ class XlaTransferManager { // Transfer manager, for marshalling data to and from the device. 
xla::TransferManager* transfer_manager_; // True if we must use XLA's TransferManager for correct device transfers. - const bool transfer_as_literal_; - const XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + bool transfer_as_literal_; }; // DeviceContext for operators assigned to XlaDevice devices. The @@ -80,9 +77,8 @@ class XlaTransferManager { // wraps the methods in XlaTransferManager. class XlaDeviceContext : public DeviceContext { public: - explicit XlaDeviceContext( - se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal, - XlaCompiler::ShapeRepresentationFn shape_representation_fn); + explicit XlaDeviceContext(se::Stream* stream, xla::LocalClient* client, + bool transfer_as_literal); void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 26842fbe5c..a8afbf9dcd 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -48,8 +48,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, Status status = XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, registration, - /*transfer_as_literal=*/false, - /*shape_representation_fn=*/{}, &device); + /*transfer_as_literal=*/false, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. 
VLOG(1) << "Failed to create XLA_GPU device: " << status; diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index d0c7a93651..6a0f557627 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -195,6 +195,11 @@ void XlaComputationLaunchContext::PopulateOutputs( OP_REQUIRES_OK( ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor)); + if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) { + OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer( + const_tensor.dtype(), const_tensor.shape(), + client_, stream->parent()->device_ordinal())); + } Device* device = dynamic_cast<Device*>(ctx->device()); OP_REQUIRES(ctx, device != nullptr, diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 213ab95a12..96dfc8d8f1 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -42,7 +42,7 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", "//tensorflow/python:random_seed", "//tensorflow/python:session", @@ -58,7 +58,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -72,7 +72,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -93,7 +93,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + 
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -111,7 +111,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:bitwise_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -127,7 +127,7 @@ tf_xla_py_test( tags = ["optonly"], deps = [ ":xla_test", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -141,7 +141,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -156,7 +156,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -170,7 +170,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -184,7 +184,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradient_checker", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", @@ -209,7 +209,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework", + 
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradient_checker", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", @@ -225,7 +225,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -241,7 +241,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -263,7 +263,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -291,7 +291,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) @@ -307,7 +307,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) @@ -326,7 +326,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:layers", "//tensorflow/python:math_ops", "//tensorflow/python:nn", @@ -346,7 +346,7 @@ tf_xla_py_test( "//tensorflow/contrib/signal:signal_py", "//tensorflow/python:array_ops", "//tensorflow/python:extra_py_tests_deps", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", 
"//tensorflow/python:platform_test", "//tensorflow/python:spectral_ops", ], @@ -360,7 +360,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) @@ -372,7 +372,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -388,7 +388,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) @@ -403,7 +403,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:image_ops", "//tensorflow/python:platform_test", ], @@ -431,7 +431,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:nn", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -446,7 +446,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) @@ -458,7 +458,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -472,7 +472,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - 
"//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -485,7 +485,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) @@ -498,7 +498,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -513,7 +513,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", "//tensorflow/python:platform_test", @@ -530,7 +530,7 @@ tf_xla_py_test( ], deps = [ ":xla_test", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", ], @@ -545,7 +545,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:errors", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -561,7 +561,7 @@ tf_xla_py_test( "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", "//tensorflow/python:errors", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -574,7 +574,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", ], ) @@ -586,7 +586,7 @@ 
tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) @@ -598,7 +598,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:training", @@ -613,7 +613,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -626,7 +626,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:platform_test", @@ -641,7 +641,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -657,7 +657,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) @@ -670,7 +670,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/contrib/stateless", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) @@ -684,7 +684,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", 
"//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -703,7 +703,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", ], @@ -716,7 +716,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", @@ -730,7 +730,7 @@ tf_xla_py_test( srcs = ["fused_batchnorm_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn", @@ -749,7 +749,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", "//tensorflow/python:nn_ops", @@ -768,7 +768,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/compiler/tf2xla/python:xla", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", "//tensorflow/python:training", ], @@ -783,7 +783,7 @@ tf_xla_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) @@ -795,7 +795,7 @@ tf_xla_py_test( deps = [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) @@ -808,34 +808,21 @@ tf_xla_py_test( deps = [ ":xla_test", 
"//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) -tf_xla_py_test( +cuda_py_test( name = "xla_device_test", size = "small", srcs = ["xla_device_test.py"], - tags = ["optonly"], - deps = [ - ":xla_test", - "//tensorflow/python:array_ops", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_test( - name = "xla_device_gpu_test", - size = "small", - srcs = ["xla_device_gpu_test.py"], additional_deps = [ "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", ], ) @@ -852,6 +839,7 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", "//tensorflow/python:layers", "//tensorflow/python:math_ops", @@ -899,7 +887,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:variables", @@ -914,7 +902,7 @@ cuda_py_test( ":xla_test", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", @@ -952,7 +940,7 @@ tf_xla_py_test( srcs = ["fake_quant_ops_test.py"], deps = [ ":xla_test", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) @@ -964,7 +952,7 @@ tf_xla_py_test( deps 
= [ ":xla_test", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform_test", ], ) diff --git a/tensorflow/compiler/tests/xla_device_gpu_test.py b/tensorflow/compiler/tests/xla_device_gpu_test.py deleted file mode 100644 index 1e30ebd55d..0000000000 --- a/tensorflow/compiler/tests/xla_device_gpu_test.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Test cases for XLA devices.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.client import session as session_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test - - -class XlaDeviceGpuTest(test.TestCase): - - def testCopiesToAndFromGpuWork(self): - """Tests that copies between GPU and XLA devices work.""" - if not test.is_gpu_available(): - return - - with session_lib.Session() as sess: - x = array_ops.placeholder(dtypes.float32, [2]) - with ops.device("GPU"): - y = x * 2 - with ops.device("device:XLA_CPU:0"): - z = y * y - with ops.device("GPU"): - w = y + z - result = sess.run(w, {x: [1.5, 0.5]}) - self.assertAllClose(result, [12., 2.], rtol=1e-3) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py index b707bd0963..f5c228f830 100644 --- a/tensorflow/compiler/tests/xla_device_test.py +++ b/tensorflow/compiler/tests/xla_device_test.py @@ -1,4 +1,4 @@ -# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -18,33 +18,30 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - -from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class XlaDeviceTest(XLATestCase): +class XlaDeviceTest(test.TestCase): def testCopies(self): - """Tests that copies onto and off XLA devices work.""" - shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3], - [16384, 1], [1, 16384], [1, 20000, 1, 1]] - for dtype in self.numeric_types: - for shape in shapes: - with self.test_session() as sess: - with ops.device("CPU"): - x = array_ops.placeholder(dtype, shape) - with self.test_scope(): - y = x + x - with ops.device("CPU"): - z = array_ops.identity(y) - - inputs = np.random.randint(-100, 100, shape).astype(dtype) - result = sess.run(z, {x: inputs}) - self.assertAllCloseAccordingToType(result, inputs + inputs) + """Tests that copies between GPU and XLA devices work.""" + if not test.is_gpu_available(): + return + + with session_lib.Session() as sess: + x = array_ops.placeholder(dtypes.float32, [2]) + with ops.device("GPU"): + y = x * 2 + with ops.device("device:XLA_CPU:0"): + z = y * y + with ops.device("GPU"): + w = y + z + result = sess.run(w, {x: [1.5, 0.5]}) + self.assertAllClose(result, [12., 2.], rtol=1e-3) if __name__ == "__main__": diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index cd57452302..4fca51f54d 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -325,7 +325,6 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", 
"//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:cpu_plugin", diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index a567226f55..70547290ea 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -55,24 +55,18 @@ class RetvalOp : public XlaOpKernel { } XlaContext& tc = XlaContext::Get(ctx); - if (tc.resolve_compile_time_constants() && - (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { + if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) { xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { - TensorShape shape = ctx->InputShape(0); - TensorShape representation_shape = - tc.is_entry_computation() - ? tc.RepresentationShape(shape, ctx->input_type(0)) - : shape; // The core from which a return value is returned depends on the core - // assignment of the input to the retval. Since we can't change the core - // assignment of <input> as this point, we must always introduce a - // reshape here, even if the shape does not change. - xla::XlaOp reshape = - ctx->builder()->Reshape(input, representation_shape.dim_sizes()); - tc.AddRetval(index_, dtype_, shape, reshape); + // assignment of the input to the retval .Since we can't change the core + // assignment of <input> as this point, create a tuple/get-tuple-element + // combination so that the core will be set on them. 
+ auto tuple_elem = + ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0); + tc.AddRetval(index_, dtype_, tuple_elem); } } } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 962e5340c0..3d1946c332 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -15,9 +15,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include <deque> #include <numeric> -#include <vector> +#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/graph_compiler.h" @@ -27,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" @@ -38,6 +40,7 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/version.h" namespace tensorflow { namespace { @@ -108,9 +111,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) flib_runtime_ = pflr_->GetFLR(device_->name()); // The default variable representation shape is the identity function. 
- if (!options_.shape_representation_fn) { - options_.shape_representation_fn = [](const TensorShape& shape, - DataType type) { return shape; }; + if (!options_.variable_representation_shape_fn) { + options_.variable_representation_shape_fn = + [](const TensorShape& shape, DataType type) { return shape; }; } } @@ -227,25 +230,20 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, // Computes the XLA shape for argument 'arg'. Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, - bool is_entry_computation, xla::Shape* xla_shape) { switch (arg.kind) { case XlaCompiler::Argument::kConstant: - LOG(FATAL) << "Unreachable case"; - case XlaCompiler::Argument::kParameter: { - TensorShape shape = - is_entry_computation - ? options_.shape_representation_fn(arg.shape, arg.type) - : arg.shape; - return TensorShapeToXLAShape(arg.type, shape, xla_shape); - } + return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), + xla_shape); + case XlaCompiler::Argument::kParameter: + return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); switch (arg.resource_kind) { case XlaResource::kVariable: { TensorShape representation_shape = - options_.shape_representation_fn(arg.shape, arg.type); + options_.variable_representation_shape_fn(arg.shape, arg.type); return TensorShapeToXLAShape(arg.type, representation_shape, xla_shape); } @@ -339,25 +337,16 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph, Status BuildComputation( const std::vector<XlaCompiler::Argument>& args, const std::vector<int>& arg_cores, - const std::vector<XlaContext::Retval>& retvals, + const std::vector<XlaExpression>& retvals, const std::vector<std::unique_ptr<XlaResource>>& resources, bool return_updated_values_for_all_resources, xla::XlaBuilder* builder, xla::XlaComputation* computation, int* num_computation_outputs, int* num_nonconst_outputs, - 
std::vector<XlaCompiler::OutputDescription>* outputs, std::vector<XlaCompiler::ResourceUpdate>* resource_updates) { std::vector<xla::XlaOp> elems; elems.reserve(retvals.size()); - for (int i = 0; i < retvals.size(); ++i) { - XlaCompiler::OutputDescription& output = (*outputs)[i]; - output.type = retvals[i].type; - output.shape = retvals[i].shape; - const XlaExpression& retval = retvals[i].expression; - if (retval.has_constant_value()) { - output.is_constant = true; - output.constant_value = retval.constant_value(); - } else { - output.is_constant = false; + for (const XlaExpression& retval : retvals) { + if (!retval.has_constant_value()) { elems.push_back(retval.handle()); } } @@ -501,8 +490,8 @@ Status XlaCompiler::BuildArguments( std::vector<xla::Shape> arg_shapes(input_mapping->size()); for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) { // Computes the shapes of non-constant arguments. - TF_RETURN_IF_ERROR(XLAShapeForArgument( - args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i])); + TF_RETURN_IF_ERROR( + XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i])); } if (use_tuple_arg) { @@ -578,8 +567,7 @@ Status XlaCompiler::BuildArguments( builder->ClearOpMetadata(); - // Fill in the handles in non-constant arguments, and reshape parameters - // back to their correct shapes. + // Fill in the handles in non-constant arguments. VLOG(2) << "XLA computation inputs:"; for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) { const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; @@ -598,9 +586,7 @@ Status XlaCompiler::BuildArguments( break; } case XlaCompiler::Argument::kParameter: - // Reshape parameters back to their correct shapes. 
- arg_expression.set_handle( - builder->Reshape(arg_handles[i], arg.shape.dim_sizes())); + arg_expression.set_handle(arg_handles[i]); break; case XlaCompiler::Argument::kConstant: case XlaCompiler::Argument::kInvalid: @@ -675,10 +661,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, FunctionalizeControlFlow(graph.get(), local_flib_def_.get())); xla::XlaBuilder builder(name); - XlaContext* context = new XlaContext( - this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants, options.is_entry_computation, - &options_.shape_representation_fn); + XlaContext* context = + new XlaContext(this, &builder, options_.allow_cpu_custom_calls, + options.resolve_compile_time_constants, + &options_.variable_representation_shape_fn); core::ScopedUnref context_unref(context); std::vector<XlaExpression> arg_expressions; @@ -695,22 +681,35 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_nonconst_outputs; int num_computation_outputs; result->computation = std::make_shared<xla::XlaComputation>(); - result->outputs.resize(context->retvals().size()); TF_RETURN_IF_ERROR(BuildComputation( args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, &builder, result->computation.get(), &num_computation_outputs, - &num_nonconst_outputs, &result->outputs, &result->resource_updates)); + &num_nonconst_outputs, &result->resource_updates)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; + result->outputs.resize(context->retvals().size()); + for (std::vector<XlaExpression>::size_type i = 0; + i < context->retvals().size(); ++i) { + const XlaExpression& retval = context->retvals()[i]; + if (retval.has_constant_value()) { + OutputDescription& output = result->outputs[i]; + output.shape = retval.constant_value().shape(); + output.is_constant = true; + output.constant_value = 
retval.constant_value(); + } + } - // Compute the XLA output shape, if there is a computation with non-constant + // Compute the output shapes, if there is a computation with non-constant // outputs. - TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::ProgramShape> computation_shape, - client()->GetComputationShape(*result->computation)); + auto computation_shape = client()->GetComputationShape(*result->computation); + if (!computation_shape.ok()) { + return computation_shape.status(); + } - result->xla_output_shape.Swap(computation_shape->mutable_result()); + result->xla_output_shape.Swap( + computation_shape.ValueOrDie()->mutable_result()); VLOG(2) << "XLA output shape: " << xla::ShapeUtil::HumanString(result->xla_output_shape); @@ -725,6 +724,23 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, // Tensorflow expects a major-to-minor order of results. xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); + // Converts the output shapes to TensorShapes. + int computation_output = 0; + for (std::vector<XlaExpression>::size_type i = 0; + i < context->retvals().size(); ++i) { + const XlaExpression& retval = context->retvals()[i]; + if (!retval.has_constant_value()) { + TF_RET_CHECK(computation_output < num_computation_outputs) + << "Computation has more outputs than expected"; + OutputDescription& output = result->outputs[i]; + output.is_constant = false; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape( + xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape, + computation_output), + &output.shape)); + ++computation_output; + } + } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 621fbc149a..ca6cd822ef 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -67,15 +67,6 @@ class XlaContext; // _Retval values are ordered by _Retval index, whereas kResource values are // ordered by the original _Arg position 
of the variable. // -// If a shape representation function is provided as part of -// XlaCompiler::CompileOptions, kParameter arguments and return values to an -// entry computation will be reshaped in accordance to the shape function. -// Arguments and return values to a non-entry computation are not reshaped. -// Variable resource arguments are passed and returned in reshaped form, even -// for non-entry computations. This feature allows TensorFlow to keep on-device -// tensors with a different shape to their representation inside the XLA -// computation. -// // In both inputs and outputs, kResource values are placed the end. When // emitting While loop bodies, we must ensure that the loop body has // identical input and output signatures. By moving variable values @@ -180,7 +171,7 @@ class XlaCompiler { }; struct OutputDescription { - // Type and shape of the output. The shape is the unflattened shape. + // Type and shape of the output. DataType type; TensorShape shape; @@ -215,12 +206,10 @@ class XlaCompiler { // original arguments, and are not necessarily in the same order.) std::vector<int> input_mapping; - // Input shapes of the computation. If we are flattening inputs, these are - // the flattened shapes. + // Input shapes of the computation. std::vector<xla::Shape> xla_input_shapes; - // Output shape in XLA format. The output shape is always a tuple. If we - // are flattening outputs, these are the flattened shapes. + // Output shape in XLA format. The output shape is always a tuple. xla::Shape xla_output_shape; // TensorFlow shapes of outputs, together with the values of any @@ -241,8 +230,6 @@ class XlaCompiler { std::shared_ptr<xla::XlaComputation> computation; }; - typedef std::function<TensorShape(const TensorShape&, DataType)> - ShapeRepresentationFn; struct Options { // Name of the compilation device to use. Needs to be live only during // XlaCompiler's constructor. 
@@ -263,7 +250,8 @@ class XlaCompiler { // If set, the XLA representation of variables represented to XLA as the // shape given by this shape function. Variables are reshaped to this shape // on write, and reshaped to their original shape on read. - ShapeRepresentationFn shape_representation_fn; + std::function<TensorShape(const TensorShape&, DataType)> + variable_representation_shape_fn; // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation @@ -312,8 +300,7 @@ class XlaCompiler { // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. - Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation, - xla::Shape* xla_shape); + Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 5670545f9d..4382ffe6ba 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" @@ -751,7 +750,10 @@ TEST_F(XlaCompilerTest, Variables) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } -xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() { +// Tests a simple graph that reads and writes a variable, with a +// variable_representation_shape_fn passed to the compiler that flattens all +// variable tensors to vectors. +TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); @@ -762,15 +764,7 @@ xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() { auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1)); auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); - TF_RETURN_IF_ERROR(scope.ToGraph(graph.get())); - return std::move(graph); -} - -// Tests a simple graph that reads and writes a variable, with a -// shape_representation_fn passed to the compiler that flattens all -// variable tensors to vectors. -TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph()); + TF_ASSERT_OK(scope.ToGraph(graph.get())); // Builds a description of the arguments. std::vector<XlaCompiler::Argument> args(2); @@ -785,33 +779,15 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { // Compiles the graph. 
XlaCompiler::Options options = DefaultOptions(); - options.shape_representation_fn = [](const TensorShape& shape, - DataType type) { + options.variable_representation_shape_fn = [](const TensorShape& shape, + DataType type) { return TensorShape({shape.num_elements()}); }; XlaCompiler compiler(options); - XlaCompiler::CompileOptions compile_options; - compile_options.is_entry_computation = false; // Only reshape variables. - XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), - args, &result)); - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape, - client_->GetComputationShape(*result.computation)); - - ASSERT_EQ(program_shape->parameters_size(), 2); - EXPECT_TRUE( - xla::ShapeUtil::Compatible(program_shape->parameters(0), - xla::ShapeUtil::MakeShape(xla::S32, {2, 2}))); - EXPECT_TRUE(xla::ShapeUtil::Compatible( - program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); - EXPECT_TRUE(xla::ShapeUtil::Compatible( - program_shape->result(), - xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {2, 2}), - xla::ShapeUtil::MakeShape(xla::S32, {4})}))); + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); // Tests that the generated computation works. std::unique_ptr<xla::Literal> param0_literal = @@ -839,74 +815,5 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); } -TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) { - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph()); - - // Builds a description of the arguments. 
- std::vector<XlaCompiler::Argument> args(2); - args[0].kind = XlaCompiler::Argument::kParameter; - args[0].type = DT_INT32; - args[0].shape = TensorShape({2, 2}); - args[1].kind = XlaCompiler::Argument::kResource; - args[1].resource_kind = XlaResource::kVariable; - args[1].initialized = true; - args[1].type = DT_INT32; - args[1].shape = TensorShape({2, 2}); - - // Compiles the graph. - XlaCompiler::Options options = DefaultOptions(); - options.shape_representation_fn = [](const TensorShape& shape, - DataType type) { - return TensorShape({shape.num_elements()}); - }; - XlaCompiler compiler(options); - - XlaCompiler::CompileOptions compile_options; - compile_options.is_entry_computation = true; // Reshape args and retvals. - - XlaCompiler::CompilationResult result; - TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), - args, &result)); - - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape, - client_->GetComputationShape(*result.computation)); - - ASSERT_EQ(program_shape->parameters_size(), 2); - EXPECT_TRUE(xla::ShapeUtil::Compatible( - program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4}))); - EXPECT_TRUE(xla::ShapeUtil::Compatible( - program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4}))); - EXPECT_TRUE(xla::ShapeUtil::Compatible( - program_shape->result(), - xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {4}), - xla::ShapeUtil::MakeShape(xla::S32, {4})}))); - - // Tests that the generated computation works. 
- std::unique_ptr<xla::Literal> param0_literal = - xla::Literal::CreateR1<int32>({4, 55, 1, -3}); - std::unique_ptr<xla::Literal> param1_literal = - xla::Literal::CreateR1<int32>({22, 11, 33, 404}); - std::unique_ptr<xla::GlobalData> param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr<xla::GlobalData> param1_data = - client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); - - std::unique_ptr<xla::GlobalData> actual = - client_ - ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) - .ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> actual_literal = - client_->Transfer(*actual).ConsumeValueOrDie(); - - std::unique_ptr<xla::Literal> expected0 = - xla::Literal::CreateR1<int32>({27, 67, 35, 402}); - std::unique_ptr<xla::Literal> expected1 = - xla::Literal::CreateR1<int32>({26, 66, 34, 401}); - std::unique_ptr<xla::Literal> expected_literal = - xla::Literal::MakeTuple({expected0.get(), expected1.get()}); - EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal)); -} - } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 098072d33c..3dd2d183f3 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -65,30 +65,26 @@ void XlaContext::set_args(std::vector<XlaExpression> args) { XlaContext::XlaContext( XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, - bool is_entry_computation, const std::function<TensorShape(const TensorShape&, DataType)>* - shape_representation_fn) + variable_representation_shape_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), resolve_compile_time_constants_(resolve_compile_time_constants), - is_entry_computation_(is_entry_computation), - shape_representation_fn_(shape_representation_fn) {} + 
variable_representation_shape_fn_(variable_representation_shape_fn) {} string XlaContext::DebugString() { return "TLA JIT context"; } // This is called by the Retval Op to associate a computed value // with a specific return value of the subgraph. void XlaContext::AddRetval(int retval_index, DataType type, - const TensorShape& shape, const xla::XlaOp& handle) { + const xla::XlaOp& handle) { VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; // Add the return value to the list being built up. if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - XlaExpression e; - e.set_handle(handle); - retvals_[retval_index] = Retval{type, shape, e}; + retvals_[retval_index].set_handle(handle); } Status XlaContext::AddConstRetval(int retval_index, DataType dtype, @@ -98,11 +94,13 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, if (retvals_.size() <= retval_index) { retvals_.resize(retval_index + 1); } - Tensor value; - TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); - XlaExpression e; - e.set_constant_value(value); - retvals_[retval_index] = Retval{dtype, value.shape(), e}; + if (resolve_compile_time_constants_) { + Tensor value; + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); + retvals_[retval_index].set_constant_value(std::move(value)); + } else { + retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal)); + } return Status::OK(); } @@ -119,9 +117,9 @@ Status XlaContext::CreateResource( return Status::OK(); } -TensorShape XlaContext::RepresentationShape(const TensorShape& shape, - DataType type) const { - return (*shape_representation_fn_)(shape, type); +TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape, + DataType type) const { + return (*variable_representation_shape_fn_)(shape, type); } const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h 
b/tensorflow/compiler/tf2xla/xla_context.h index 3ad2b2e44e..1136ffe507 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -42,13 +42,11 @@ class XlaContext : public ResourceBase { static XlaContext& Get(const OpKernelContext* ctx); static XlaContext& Get(const XlaOpKernelContext* ctx); - // Creates a new XlaContext. See the documentation on the class data fields - // for descriptions of the arguments. + // Creates a new XlaContext. XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, bool allow_cpu_custom_calls, bool resolve_compile_time_constants, - bool is_entry_computation, const std::function<TensorShape(const TensorShape&, DataType)>* - shape_representation_fn); + variable_representation_shape_fn); // Virtual method defined by ResourceBase. string DebugString() override; @@ -60,26 +58,14 @@ class XlaContext : public ResourceBase { bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } - bool resolve_compile_time_constants() const { - return resolve_compile_time_constants_; - } - bool is_entry_computation() const { return is_entry_computation_; } - const std::vector<XlaExpression>& args() const { return args_; } void set_args(std::vector<XlaExpression> args); - struct Retval { - DataType type; - TensorShape shape; - // An XlaExpression representing the Retval's value. - XlaExpression expression; - }; - const std::vector<Retval>& retvals() { return retvals_; } + const std::vector<XlaExpression>& retvals() { return retvals_; } // This is called by the Retval Op to associate a computed value // with a specific return value of the subgraph. - void AddRetval(int retval_index, DataType type, const TensorShape& shape, - const xla::XlaOp& handle); + void AddRetval(int retval_index, DataType type, const xla::XlaOp& handle); // As for Retval, but for return values that are compile-time constants. 
Status AddConstRetval(int retval_index, DataType dtype, @@ -100,9 +86,9 @@ class XlaContext : public ResourceBase { } // Returns the XLA shape to be used to represent a variable of TF `shape` - // and `type`, or of an argument or return value of a top-level computation. - TensorShape RepresentationShape(const TensorShape& shape, - DataType type) const; + // and `type`. + TensorShape VariableRepresentationShape(const TensorShape& shape, + DataType type) const; // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a @@ -145,19 +131,15 @@ class XlaContext : public ResourceBase { std::vector<XlaExpression> args_; // Return values of the Tensorflow graph, indexed by _Retval index. - std::vector<Retval> retvals_; + std::vector<XlaExpression> retvals_; // Holds ownership of resources. The resources are not ordered. std::vector<std::unique_ptr<XlaResource>> resources_; - // Is this a top-level computation, or an inner computation (e.g., a while - // body)? - const bool is_entry_computation_; - // A function that describes how variable shapes should be represented // in XLA. Variable values will be reshaped to this shape. Must be non-null. const std::function<TensorShape(const TensorShape&, DataType)>* - shape_representation_fn_; + variable_representation_shape_fn_; // Cache of prebuilt computations indexed by their type. 
using ComputationMap = std::map<DataType, xla::XlaComputation>; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 76c68d81af..2b65f4d5d5 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -314,8 +314,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type, } XlaContext& xla_context = XlaContext::Get(context_); - TensorShape representation_shape = - xla_context.RepresentationShape(variable->shape(), variable->type()); + TensorShape representation_shape = xla_context.VariableRepresentationShape( + variable->shape(), variable->type()); if (representation_shape == variable->shape()) { *value = variable->value(); } else { @@ -436,7 +436,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, XlaContext& xla_context = XlaContext::Get(context_); TensorShape representation_shape = - xla_context.RepresentationShape(shape, type); + xla_context.VariableRepresentationShape(shape, type); if (shape != representation_shape) { handle = builder()->Reshape(handle, representation_shape.dim_sizes()); } |