aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2018-05-15 10:33:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-15 10:36:11 -0700
commita7a3bb3df12c632b81bf1b23f8405f92a0c903c3 (patch)
tree5291bb0bed562ce40c5a13014c6c955c71e4e10a
parent2c83cddab0cd0de78f863e47d81b4427d6519eb7 (diff)
Automated g4 rollback of changelist 196683444
PiperOrigin-RevId: 196691101
-rw-r--r--tensorflow/compiler/aot/tests/tfcompile_test.cc8
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc16
-rw-r--r--tensorflow/compiler/jit/xla_cpu_device.cc9
-rw-r--r--tensorflow/compiler/jit/xla_device.cc43
-rw-r--r--tensorflow/compiler/jit/xla_device.h33
-rw-r--r--tensorflow/compiler/jit/xla_device_context.cc49
-rw-r--r--tensorflow/compiler/jit/xla_device_context.h14
-rw-r--r--tensorflow/compiler/jit/xla_gpu_device.cc3
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc5
-rw-r--r--tensorflow/compiler/tests/BUILD124
-rw-r--r--tensorflow/compiler/tests/xla_device_gpu_test.py48
-rw-r--r--tensorflow/compiler/tests/xla_device_test.py39
-rw-r--r--tensorflow/compiler/tf2xla/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/kernels/retval_op.cc20
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.cc102
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler.h25
-rw-r--r--tensorflow/compiler/tf2xla/xla_compiler_test.cc111
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.cc30
-rw-r--r--tensorflow/compiler/tf2xla/xla_context.h36
-rw-r--r--tensorflow/compiler/tf2xla/xla_op_kernel.cc6
20 files changed, 247 insertions, 475 deletions
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index fee46280e9..868d752927 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -551,16 +551,14 @@ TEST(TFCompileTest, HloProfiling) {
auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]");
auto dot_profile_line = HasSubstr(
- "%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
+ "%dot.0.2 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%arg1.0.1)");
auto add_profile_line = HasSubstr(
- "%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
+ "%add.0.5 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%arg1.0.1)");
auto tuple_profile_line = HasSubstr(
"%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
- "%dot.0.4, f32[2,2]{1,0} %add.0.6)");
- auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
- auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
+ "%dot.0.2, f32[2,2]{1,0} %add.0.5)");
EXPECT_THAT(hlo_profile_lines,
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index 9d856346ec..86a9fd3b8e 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -112,7 +112,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
- const XlaDevice::Metadata* metadata = nullptr;
+ const XlaDevice::Metadata* metadata;
Status s = XlaDevice::GetMetadata(ctx, &metadata);
bool allocate_xla_tensors = s.ok();
@@ -153,9 +153,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
options.device_allocator = xla_allocator;
- if (metadata) {
- options.shape_representation_fn = metadata->shape_representation_fn();
- }
+ // TODO(b/77671268): We don't set variable_representation_shape_fn here. This
+ // is restricted to Variables, but we need something like this to apply to
+ // normal Tensors too.
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
@@ -164,11 +164,9 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
for (int i : constants_) {
constant_args.insert({i, ctx->input(i)});
}
- XlaCompiler::CompileOptions compile_options;
- compile_options.is_entry_computation = true;
- OP_REQUIRES_OK(
- ctx, cache->Compile(options, function_, constant_args, variables, ctx,
- &kernel, &executable, &compile_options));
+ OP_REQUIRES_OK(ctx, cache->Compile(options, function_, constant_args,
+ variables, ctx, &kernel, &executable,
+ /*compile_options=*/nullptr));
VLOG(1) << "Executing XLA Computation...";
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index ea9e036604..bc07dbd7bd 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -50,11 +50,10 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options,
(void)registrations;
std::unique_ptr<XlaDevice> device;
- TF_RETURN_IF_ERROR(
- XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options,
- name_prefix, registration,
- /*transfer_as_literal=*/false,
- /*shape_representation_fn=*/{}, &device));
+ TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0,
+ DEVICE_CPU_XLA_JIT, options, name_prefix,
+ registration,
+ /*transfer_as_literal=*/false, &device));
devices->push_back(device.release());
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 9ee5b04e80..70263b1ff9 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -110,9 +110,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
const string& jit_device_name, const SessionOptions& options,
const string& name_prefix,
const XlaOpRegistry::DeviceRegistration& registration,
- bool transfer_as_literal,
- const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
- std::unique_ptr<XlaDevice>* device) {
+ bool transfer_as_literal, std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
<< device_ordinal;
@@ -131,19 +129,17 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
DeviceType(device_name), Bytes(16ULL << 30), DeviceLocality(),
strings::StrCat("device: ", device_name, " device"));
- device->reset(new XlaDevice(
- options, attrs, device_ordinal, DeviceType(jit_device_name),
- platform.ValueOrDie(), transfer_as_literal, shape_representation_fn));
+ device->reset(new XlaDevice(options, attrs, device_ordinal,
+ DeviceType(jit_device_name),
+ platform.ValueOrDie(), transfer_as_literal));
return Status::OK();
}
-XlaDevice::Metadata::Metadata(
- int device_ordinal, se::Platform* platform, const DeviceType& device_type,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+XlaDevice::Metadata::Metadata(int device_ordinal, se::Platform* platform,
+ const DeviceType& device_type)
: device_ordinal_(device_ordinal),
device_type_(device_type),
- platform_(platform),
- shape_representation_fn_(std::move(shape_representation_fn)) {}
+ platform_(platform) {}
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
@@ -174,20 +170,17 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const {
return Status::OK();
}
-XlaDevice::XlaDevice(
- const SessionOptions& options, const DeviceAttributes& attrs,
- int device_ordinal, const DeviceType& jit_device_name,
- se::Platform* platform, bool transfer_as_literal,
- const XlaCompiler::ShapeRepresentationFn& shape_representation_fn)
+XlaDevice::XlaDevice(const SessionOptions& options,
+ const DeviceAttributes& attrs, int device_ordinal,
+ const DeviceType& jit_device_name, se::Platform* platform,
+ bool transfer_as_literal)
: LocalDevice(options, attrs),
- xla_metadata_(device_ordinal, platform, jit_device_name,
- shape_representation_fn),
+ xla_metadata_(device_ordinal, platform, jit_device_name),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
platform_(platform),
- transfer_as_literal_(transfer_as_literal),
- shape_representation_fn_(shape_representation_fn) {
+ transfer_as_literal_(transfer_as_literal) {
VLOG(1) << "Created XLA device " << jit_device_name;
}
@@ -239,8 +232,8 @@ Status XlaDevice::CreateAndSetGpuDeviceInfo() {
// gpu_device_info_->default_context.
gpu_device_info_ = absl::make_unique<GpuDeviceInfo>();
gpu_device_info_->stream = stream;
- gpu_device_info_->default_context = new XlaDeviceContext(
- stream, client(), transfer_as_literal_, shape_representation_fn_);
+ gpu_device_info_->default_context =
+ new XlaDeviceContext(stream, client(), transfer_as_literal_);
set_tensorflow_gpu_device_info(gpu_device_info_.get());
}
@@ -254,8 +247,7 @@ Status XlaDevice::FillContextMap(const Graph* graph,
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
// Call GetAllocator for the side-effect of ensuring the allocator is created.
GetAllocator({});
- auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_,
- shape_representation_fn_);
+ auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_);
for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
ctx->Ref();
@@ -302,8 +294,7 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
Tensor copy(GetAllocator(alloc_attrs), parsed.dtype(), parsed.shape());
Notification n;
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
- XlaTransferManager manager(stream, client(), transfer_as_literal_,
- shape_representation_fn_);
+ XlaTransferManager manager(stream, client(), transfer_as_literal_);
manager.CopyCPUTensorToDevice(&parsed, this, &copy,
[&n, &status](const Status& s) {
status = s;
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index d5d345d43b..3ae87308cc 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -17,7 +17,8 @@ limitations under the License.
// runtime.
//
// Operators assigned to an XlaDevice are compiled into XLA computations.
-// Tensors on an XlaDevice are thin wrappers around XLA ScopedShapedBuffers.
+// Tensors on an XlaDevice are thin wrappers around XLA GlobalDataHandles; state
+// is managed by XLA.
//
// XlaDevice is instantiated separately for each XLA backend (e.g., CPU or GPU),
// under different names (e.g., XLA_CPU or XLA_GPU).
@@ -26,7 +27,6 @@ limitations under the License.
#define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_
#include "tensorflow/compiler/jit/xla_tensor.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/common_runtime/device_factory.h"
@@ -50,8 +50,7 @@ class XlaDevice : public LocalDevice {
class Metadata {
public:
Metadata(int device_ordinal, se::Platform* platform,
- const DeviceType& device_type,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+ const DeviceType& device_type);
// The index of the device on this host.
int device_ordinal() const;
@@ -59,15 +58,11 @@ class XlaDevice : public LocalDevice {
se::Platform* platform() const;
xla::LocalClient* client() const;
const DeviceType& jit_device_type() const;
- const XlaCompiler::ShapeRepresentationFn& shape_representation_fn() const {
- return shape_representation_fn_;
- }
private:
const int device_ordinal_;
const DeviceType device_type_;
se::Platform* platform_; // Not owned.
- XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
};
@@ -81,19 +76,16 @@ class XlaDevice : public LocalDevice {
// 'transfer_as_literal' is true if device<->host transfers must be done using
// XLA's TransferLiteral{To,From}Device interface. If false, we can use
// ThenMemcpy instead.
- static Status Create(
- const string& platform_name, const string& device_name,
- int device_ordinal, const string& jit_device_name,
- const SessionOptions& options, const string& name_prefix,
- const XlaOpRegistry::DeviceRegistration& registration,
- bool transfer_as_literal,
- const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
- std::unique_ptr<XlaDevice>* device);
+ static Status Create(const string& platform_name, const string& device_name,
+ int device_ordinal, const string& jit_device_name,
+ const SessionOptions& options, const string& name_prefix,
+ const XlaOpRegistry::DeviceRegistration& registration,
+ bool transfer_as_literal,
+ std::unique_ptr<XlaDevice>* device);
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
- se::Platform* platform, bool transfer_as_literal,
- const XlaCompiler::ShapeRepresentationFn& shape_representation_fn);
+ se::Platform* platform, bool transfer_as_literal);
~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
@@ -124,8 +116,8 @@ class XlaDevice : public LocalDevice {
// The name of the device that is used to compile Ops for this XlaDevice.
DeviceType jit_device_name_;
// Memory allocator associated with this device.
- Allocator* xla_allocator_; // Not owned.
- se::Platform* platform_; // Not owned.
+ Allocator* xla_allocator_; // Not owned.
+ se::Platform* platform_; // Not owned.
// Stream associated with this device. Operations enqueued on this
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
@@ -134,7 +126,6 @@ class XlaDevice : public LocalDevice {
// Must we use XLA's transfer manager for correct host<->device transfers? if
// false, we can use ThenMemcpy() instead.
bool transfer_as_literal_;
- XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
// If set, holds default device context (that we must Unref)
// and its stream.
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index ff30b62bad..bf8c1886a0 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -47,14 +47,13 @@ void XlaDeviceAllocator::DeallocateRaw(void* ptr) {
void XlaDeviceAllocator::GetStats(AllocatorStats* stats) { stats->Clear(); }
-XlaTransferManager::XlaTransferManager(
- se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+XlaTransferManager::XlaTransferManager(se::Stream* stream,
+ xla::LocalClient* client,
+ bool transfer_as_literal)
: stream_(stream),
client_(client),
transfer_manager_(client->backend().transfer_manager()),
- transfer_as_literal_(transfer_as_literal),
- shape_representation_fn_(std::move(shape_representation_fn)) {}
+ transfer_as_literal_(transfer_as_literal) {}
Status XlaTransferManager::TransferLiteralToDevice(
const Tensor& host_tensor, Tensor* device_tensor) const {
@@ -77,15 +76,7 @@ Status XlaTransferManager::TransferLiteralFromDevice(
transfer_manager_->TransferLiteralFromDevice(
stream_->parent(), shaped_buffer));
VLOG(1) << "Transfer from device as literal: " << literal->ToString();
- Tensor tensor;
- TF_RETURN_IF_ERROR(
- LiteralToHostTensor(*literal, host_tensor->dtype(), &tensor));
- // Reshape the tensor back to its declared shape.
- if (!host_tensor->CopyFrom(tensor, device_tensor.shape())) {
- return errors::Internal(
- "Tensor::CopyFrom failed when copying from XLA device to CPU");
- }
- return Status::OK();
+ return LiteralToHostTensor(*literal, host_tensor->dtype(), host_tensor);
}
void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
@@ -105,17 +96,9 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
XlaTensor* xla_tensor = XlaTensor::FromTensor(device_tensor);
CHECK(xla_tensor);
-
- TensorShape shape;
- if (shape_representation_fn_) {
- shape = shape_representation_fn_(device_tensor->shape(),
- device_tensor->dtype());
- } else {
- shape = device_tensor->shape();
- }
if (!xla_tensor->has_shaped_buffer()) {
Status s = xla_tensor->AllocateShapedBuffer(
- device_tensor->dtype(), shape, client_,
+ device_tensor->dtype(), device_tensor->shape(), client_,
stream_->parent()->device_ordinal());
if (!s.ok()) {
done(s);
@@ -123,18 +106,12 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
}
}
+ se::DeviceMemoryBase dev_dst_ptr =
+ XlaTensor::DeviceMemoryFromTensor(*device_tensor);
Status status;
if (transfer_as_literal_) {
- Tensor reshaped_cpu_tensor;
- if (!reshaped_cpu_tensor.CopyFrom(*cpu_tensor, shape)) {
- done(errors::Internal(
- "Tensor::CopyFrom failed when copying from CPU to XLA device"));
- return;
- }
- status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
+ status = TransferLiteralToDevice(*cpu_tensor, device_tensor);
} else {
- se::DeviceMemoryBase dev_dst_ptr =
- XlaTensor::DeviceMemoryFromTensor(*device_tensor);
stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes);
// TODO(hpucha): Make this asynchronous.
Status block_status = stream_->BlockHostUntilDone();
@@ -194,11 +171,9 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor,
done(Status::OK());
}
-XlaDeviceContext::XlaDeviceContext(
- se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
- : manager_(stream, client, transfer_as_literal,
- std::move(shape_representation_fn)) {}
+XlaDeviceContext::XlaDeviceContext(se::Stream* stream, xla::LocalClient* client,
+ bool transfer_as_literal)
+ : manager_(stream, client, transfer_as_literal) {}
void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
Device* device,
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index 9af9655868..d7f5f1d208 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -19,7 +19,6 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/jit/xla_tensor.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/allocator.h"
@@ -46,9 +45,8 @@ class XlaDeviceAllocator : public Allocator {
// Helper class for managing data transfers between host and XLA devices.
class XlaTransferManager {
public:
- explicit XlaTransferManager(
- se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+ explicit XlaTransferManager(se::Stream* stream, xla::LocalClient* client,
+ bool transfer_as_literal);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor, StatusCallback done) const;
@@ -71,8 +69,7 @@ class XlaTransferManager {
// Transfer manager, for marshalling data to and from the device.
xla::TransferManager* transfer_manager_;
// True if we must use XLA's TransferManager for correct device transfers.
- const bool transfer_as_literal_;
- const XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+ bool transfer_as_literal_;
};
// DeviceContext for operators assigned to XlaDevice devices. The
@@ -80,9 +77,8 @@ class XlaTransferManager {
// wraps the methods in XlaTransferManager.
class XlaDeviceContext : public DeviceContext {
public:
- explicit XlaDeviceContext(
- se::Stream* stream, xla::LocalClient* client, bool transfer_as_literal,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn);
+ explicit XlaDeviceContext(se::Stream* stream, xla::LocalClient* client,
+ bool transfer_as_literal);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 26842fbe5c..a8afbf9dcd 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -48,8 +48,7 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
Status status =
XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options,
name_prefix, registration,
- /*transfer_as_literal=*/false,
- /*shape_representation_fn=*/{}, &device);
+ /*transfer_as_literal=*/false, &device);
if (!status.ok()) {
// Treat failures as non-fatal; there might not be a GPU in the machine.
VLOG(1) << "Failed to create XLA_GPU device: " << status;
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index d0c7a93651..6a0f557627 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -195,6 +195,11 @@ void XlaComputationLaunchContext::PopulateOutputs(
OP_REQUIRES_OK(
ctx, ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
+ if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
+ OP_REQUIRES_OK(ctx, xla_tensor->AllocateShapedBuffer(
+ const_tensor.dtype(), const_tensor.shape(),
+ client_, stream->parent()->device_ordinal()));
+ }
Device* device = dynamic_cast<Device*>(ctx->device());
OP_REQUIRES(ctx, device != nullptr,
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 213ab95a12..96dfc8d8f1 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -42,7 +42,7 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
"//tensorflow/python:random_seed",
"//tensorflow/python:session",
@@ -58,7 +58,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@@ -72,7 +72,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@@ -93,7 +93,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@@ -111,7 +111,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:bitwise_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn_ops",
@@ -127,7 +127,7 @@ tf_xla_py_test(
tags = ["optonly"],
deps = [
":xla_test",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
@@ -141,7 +141,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@@ -156,7 +156,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@@ -170,7 +170,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@@ -184,7 +184,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:array_ops_gen",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradient_checker",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
@@ -209,7 +209,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:array_ops_gen",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradient_checker",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
@@ -225,7 +225,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
@@ -241,7 +241,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
@@ -263,7 +263,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
@@ -291,7 +291,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
@@ -307,7 +307,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
@@ -326,7 +326,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:layers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
@@ -346,7 +346,7 @@ tf_xla_py_test(
"//tensorflow/contrib/signal:signal_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:extra_py_tests_deps",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
"//tensorflow/python:spectral_ops",
],
@@ -360,7 +360,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
@@ -372,7 +372,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@@ -388,7 +388,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
@@ -403,7 +403,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:image_ops",
"//tensorflow/python:platform_test",
],
@@ -431,7 +431,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
@@ -446,7 +446,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
@@ -458,7 +458,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@@ -472,7 +472,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@@ -485,7 +485,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
@@ -498,7 +498,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
@@ -513,7 +513,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
@@ -530,7 +530,7 @@ tf_xla_py_test(
],
deps = [
":xla_test",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
"//tensorflow/python:random_ops",
],
@@ -545,7 +545,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:errors",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@@ -561,7 +561,7 @@ tf_xla_py_test(
"//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/python:array_ops",
"//tensorflow/python:errors",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@@ -574,7 +574,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
],
)
@@ -586,7 +586,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
@@ -598,7 +598,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
@@ -613,7 +613,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@@ -626,7 +626,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:platform_test",
@@ -641,7 +641,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@@ -657,7 +657,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
@@ -670,7 +670,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/contrib/stateless",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
@@ -684,7 +684,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn_ops",
@@ -703,7 +703,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
@@ -716,7 +716,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
@@ -730,7 +730,7 @@ tf_xla_py_test(
srcs = ["fused_batchnorm_test.py"],
deps = [
":xla_test",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn",
@@ -749,7 +749,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:math_ops_gen",
"//tensorflow/python:nn_ops",
@@ -768,7 +768,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/compiler/tf2xla/python:xla",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
@@ -783,7 +783,7 @@ tf_xla_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:data_flow_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
@@ -795,7 +795,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
@@ -808,34 +808,21 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
-tf_xla_py_test(
+cuda_py_test(
name = "xla_device_test",
size = "small",
srcs = ["xla_device_test.py"],
- tags = ["optonly"],
- deps = [
- ":xla_test",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
- "//tensorflow/python:platform_test",
- ],
-)
-
-cuda_py_test(
- name = "xla_device_gpu_test",
- size = "small",
- srcs = ["xla_device_gpu_test.py"],
additional_deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
],
)
@@ -852,6 +839,7 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
"//tensorflow/python:layers",
"//tensorflow/python:math_ops",
@@ -899,7 +887,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:variables",
@@ -914,7 +902,7 @@ cuda_py_test(
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
@@ -952,7 +940,7 @@ tf_xla_py_test(
srcs = ["fake_quant_ops_test.py"],
deps = [
":xla_test",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
@@ -964,7 +952,7 @@ tf_xla_py_test(
deps = [
":xla_test",
"//tensorflow/python:array_ops",
- "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
diff --git a/tensorflow/compiler/tests/xla_device_gpu_test.py b/tensorflow/compiler/tests/xla_device_gpu_test.py
deleted file mode 100644
index 1e30ebd55d..0000000000
--- a/tensorflow/compiler/tests/xla_device_gpu_test.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test cases for XLA devices."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.python.client import session as session_lib
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-
-
-class XlaDeviceGpuTest(test.TestCase):
-
- def testCopiesToAndFromGpuWork(self):
- """Tests that copies between GPU and XLA devices work."""
- if not test.is_gpu_available():
- return
-
- with session_lib.Session() as sess:
- x = array_ops.placeholder(dtypes.float32, [2])
- with ops.device("GPU"):
- y = x * 2
- with ops.device("device:XLA_CPU:0"):
- z = y * y
- with ops.device("GPU"):
- w = y + z
- result = sess.run(w, {x: [1.5, 0.5]})
- self.assertAllClose(result, [12., 2.], rtol=1e-3)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/compiler/tests/xla_device_test.py b/tensorflow/compiler/tests/xla_device_test.py
index b707bd0963..f5c228f830 100644
--- a/tensorflow/compiler/tests/xla_device_test.py
+++ b/tensorflow/compiler/tests/xla_device_test.py
@@ -1,4 +1,4 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,33 +18,30 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class XlaDeviceTest(XLATestCase):
+class XlaDeviceTest(test.TestCase):
def testCopies(self):
- """Tests that copies onto and off XLA devices work."""
- shapes = [[0], [1], [1, 0], [1024, 0], [1024, 1], [3, 777], [777, 3],
- [16384, 1], [1, 16384], [1, 20000, 1, 1]]
- for dtype in self.numeric_types:
- for shape in shapes:
- with self.test_session() as sess:
- with ops.device("CPU"):
- x = array_ops.placeholder(dtype, shape)
- with self.test_scope():
- y = x + x
- with ops.device("CPU"):
- z = array_ops.identity(y)
-
- inputs = np.random.randint(-100, 100, shape).astype(dtype)
- result = sess.run(z, {x: inputs})
- self.assertAllCloseAccordingToType(result, inputs + inputs)
+ """Tests that copies between GPU and XLA devices work."""
+ if not test.is_gpu_available():
+ return
+
+ with session_lib.Session() as sess:
+ x = array_ops.placeholder(dtypes.float32, [2])
+ with ops.device("GPU"):
+ y = x * 2
+ with ops.device("device:XLA_CPU:0"):
+ z = y * y
+ with ops.device("GPU"):
+ w = y + z
+ result = sess.run(w, {x: [1.5, 0.5]})
+ self.assertAllClose(result, [12., 2.], rtol=1e-3)
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index cd57452302..4fca51f54d 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -325,7 +325,6 @@ tf_cc_test(
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/service:cpu_plugin",
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index a567226f55..70547290ea 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -55,24 +55,18 @@ class RetvalOp : public XlaOpKernel {
}
XlaContext& tc = XlaContext::Get(ctx);
- if (tc.resolve_compile_time_constants() &&
- (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
+ if (input_shape.num_elements() == 0 || is_constant.ValueOrDie()) {
xla::Literal literal;
OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
} else {
- TensorShape shape = ctx->InputShape(0);
- TensorShape representation_shape =
- tc.is_entry_computation()
- ? tc.RepresentationShape(shape, ctx->input_type(0))
- : shape;
// The core from which a return value is returned depends on the core
- // assignment of the input to the retval. Since we can't change the core
- // assignment of <input> as this point, we must always introduce a
- // reshape here, even if the shape does not change.
- xla::XlaOp reshape =
- ctx->builder()->Reshape(input, representation_shape.dim_sizes());
- tc.AddRetval(index_, dtype_, shape, reshape);
+ // assignment of the input to the retval .Since we can't change the core
+ // assignment of <input> as this point, create a tuple/get-tuple-element
+ // combination so that the core will be set on them.
+ auto tuple_elem =
+ ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0);
+ tc.AddRetval(index_, dtype_, tuple_elem);
}
}
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 962e5340c0..3d1946c332 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -15,9 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include <deque>
#include <numeric>
-#include <vector>
+#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h"
#include "tensorflow/compiler/tf2xla/graph_compiler.h"
@@ -27,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
@@ -38,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
@@ -108,9 +111,9 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
flib_runtime_ = pflr_->GetFLR(device_->name());
// The default variable representation shape is the identity function.
- if (!options_.shape_representation_fn) {
- options_.shape_representation_fn = [](const TensorShape& shape,
- DataType type) { return shape; };
+ if (!options_.variable_representation_shape_fn) {
+ options_.variable_representation_shape_fn =
+ [](const TensorShape& shape, DataType type) { return shape; };
}
}
@@ -227,25 +230,20 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options,
// Computes the XLA shape for argument 'arg'.
Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
- bool is_entry_computation,
xla::Shape* xla_shape) {
switch (arg.kind) {
case XlaCompiler::Argument::kConstant:
- LOG(FATAL) << "Unreachable case";
- case XlaCompiler::Argument::kParameter: {
- TensorShape shape =
- is_entry_computation
- ? options_.shape_representation_fn(arg.shape, arg.type)
- : arg.shape;
- return TensorShapeToXLAShape(arg.type, shape, xla_shape);
- }
+ return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(),
+ xla_shape);
+ case XlaCompiler::Argument::kParameter:
+ return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape);
case XlaCompiler::Argument::kResource: {
TF_RET_CHECK(arg.initialized);
switch (arg.resource_kind) {
case XlaResource::kVariable: {
TensorShape representation_shape =
- options_.shape_representation_fn(arg.shape, arg.type);
+ options_.variable_representation_shape_fn(arg.shape, arg.type);
return TensorShapeToXLAShape(arg.type, representation_shape,
xla_shape);
}
@@ -339,25 +337,16 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
Status BuildComputation(
const std::vector<XlaCompiler::Argument>& args,
const std::vector<int>& arg_cores,
- const std::vector<XlaContext::Retval>& retvals,
+ const std::vector<XlaExpression>& retvals,
const std::vector<std::unique_ptr<XlaResource>>& resources,
bool return_updated_values_for_all_resources, xla::XlaBuilder* builder,
xla::XlaComputation* computation, int* num_computation_outputs,
int* num_nonconst_outputs,
- std::vector<XlaCompiler::OutputDescription>* outputs,
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
std::vector<xla::XlaOp> elems;
elems.reserve(retvals.size());
- for (int i = 0; i < retvals.size(); ++i) {
- XlaCompiler::OutputDescription& output = (*outputs)[i];
- output.type = retvals[i].type;
- output.shape = retvals[i].shape;
- const XlaExpression& retval = retvals[i].expression;
- if (retval.has_constant_value()) {
- output.is_constant = true;
- output.constant_value = retval.constant_value();
- } else {
- output.is_constant = false;
+ for (const XlaExpression& retval : retvals) {
+ if (!retval.has_constant_value()) {
elems.push_back(retval.handle());
}
}
@@ -501,8 +490,8 @@ Status XlaCompiler::BuildArguments(
std::vector<xla::Shape> arg_shapes(input_mapping->size());
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
// Computes the shapes of non-constant arguments.
- TF_RETURN_IF_ERROR(XLAShapeForArgument(
- args[(*input_mapping)[i]], is_entry_computation, &arg_shapes[i]));
+ TF_RETURN_IF_ERROR(
+ XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i]));
}
if (use_tuple_arg) {
@@ -578,8 +567,7 @@ Status XlaCompiler::BuildArguments(
builder->ClearOpMetadata();
- // Fill in the handles in non-constant arguments, and reshape parameters
- // back to their correct shapes.
+ // Fill in the handles in non-constant arguments.
VLOG(2) << "XLA computation inputs:";
for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
const XlaCompiler::Argument& arg = args[input_mapping->at(i)];
@@ -598,9 +586,7 @@ Status XlaCompiler::BuildArguments(
break;
}
case XlaCompiler::Argument::kParameter:
- // Reshape parameters back to their correct shapes.
- arg_expression.set_handle(
- builder->Reshape(arg_handles[i], arg.shape.dim_sizes()));
+ arg_expression.set_handle(arg_handles[i]);
break;
case XlaCompiler::Argument::kConstant:
case XlaCompiler::Argument::kInvalid:
@@ -675,10 +661,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
FunctionalizeControlFlow(graph.get(), local_flib_def_.get()));
xla::XlaBuilder builder(name);
- XlaContext* context = new XlaContext(
- this, &builder, options_.allow_cpu_custom_calls,
- options.resolve_compile_time_constants, options.is_entry_computation,
- &options_.shape_representation_fn);
+ XlaContext* context =
+ new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
+ options.resolve_compile_time_constants,
+ &options_.variable_representation_shape_fn);
core::ScopedUnref context_unref(context);
std::vector<XlaExpression> arg_expressions;
@@ -695,22 +681,35 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
int num_nonconst_outputs;
int num_computation_outputs;
result->computation = std::make_shared<xla::XlaComputation>();
- result->outputs.resize(context->retvals().size());
TF_RETURN_IF_ERROR(BuildComputation(
args, arg_cores, context->retvals(), context->resources(),
options.return_updated_values_for_all_resources, &builder,
result->computation.get(), &num_computation_outputs,
- &num_nonconst_outputs, &result->outputs, &result->resource_updates));
+ &num_nonconst_outputs, &result->resource_updates));
VLOG(2) << "Outputs: total: " << context->retvals().size()
<< " nonconstant: " << num_nonconst_outputs;
+ result->outputs.resize(context->retvals().size());
+ for (std::vector<XlaExpression>::size_type i = 0;
+ i < context->retvals().size(); ++i) {
+ const XlaExpression& retval = context->retvals()[i];
+ if (retval.has_constant_value()) {
+ OutputDescription& output = result->outputs[i];
+ output.shape = retval.constant_value().shape();
+ output.is_constant = true;
+ output.constant_value = retval.constant_value();
+ }
+ }
- // Compute the XLA output shape, if there is a computation with non-constant
+ // Compute the output shapes, if there is a computation with non-constant
// outputs.
- TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::ProgramShape> computation_shape,
- client()->GetComputationShape(*result->computation));
+ auto computation_shape = client()->GetComputationShape(*result->computation);
+ if (!computation_shape.ok()) {
+ return computation_shape.status();
+ }
- result->xla_output_shape.Swap(computation_shape->mutable_result());
+ result->xla_output_shape.Swap(
+ computation_shape.ValueOrDie()->mutable_result());
VLOG(2) << "XLA output shape: "
<< xla::ShapeUtil::HumanString(result->xla_output_shape);
@@ -725,6 +724,23 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
// Tensorflow expects a major-to-minor order of results.
xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape);
+ // Converts the output shapes to TensorShapes.
+ int computation_output = 0;
+ for (std::vector<XlaExpression>::size_type i = 0;
+ i < context->retvals().size(); ++i) {
+ const XlaExpression& retval = context->retvals()[i];
+ if (!retval.has_constant_value()) {
+ TF_RET_CHECK(computation_output < num_computation_outputs)
+ << "Computation has more outputs than expected";
+ OutputDescription& output = result->outputs[i];
+ output.is_constant = false;
+ TF_RETURN_IF_ERROR(XLAShapeToTensorShape(
+ xla::ShapeUtil::GetTupleElementShape(result->xla_output_shape,
+ computation_output),
+ &output.shape));
+ ++computation_output;
+ }
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 621fbc149a..ca6cd822ef 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -67,15 +67,6 @@ class XlaContext;
// _Retval values are ordered by _Retval index, whereas kResource values are
// ordered by the original _Arg position of the variable.
//
-// If a shape representation function is provided as part of
-// XlaCompiler::CompileOptions, kParameter arguments and return values to an
-// entry computation will be reshaped in accordance to the shape function.
-// Arguments and return values to a non-entry computation are not reshaped.
-// Variable resource arguments are passed and returned in reshaped form, even
-// for non-entry computations. This feature allows TensorFlow to keep on-device
-// tensors with a different shape to their representation inside the XLA
-// computation.
-//
// In both inputs and outputs, kResource values are placed the end. When
// emitting While loop bodies, we must ensure that the loop body has
// identical input and output signatures. By moving variable values
@@ -180,7 +171,7 @@ class XlaCompiler {
};
struct OutputDescription {
- // Type and shape of the output. The shape is the unflattened shape.
+ // Type and shape of the output.
DataType type;
TensorShape shape;
@@ -215,12 +206,10 @@ class XlaCompiler {
// original arguments, and are not necessarily in the same order.)
std::vector<int> input_mapping;
- // Input shapes of the computation. If we are flattening inputs, these are
- // the flattened shapes.
+ // Input shapes of the computation.
std::vector<xla::Shape> xla_input_shapes;
- // Output shape in XLA format. The output shape is always a tuple. If we
- // are flattening outputs, these are the flattened shapes.
+ // Output shape in XLA format. The output shape is always a tuple.
xla::Shape xla_output_shape;
// TensorFlow shapes of outputs, together with the values of any
@@ -241,8 +230,6 @@ class XlaCompiler {
std::shared_ptr<xla::XlaComputation> computation;
};
- typedef std::function<TensorShape(const TensorShape&, DataType)>
- ShapeRepresentationFn;
struct Options {
// Name of the compilation device to use. Needs to be live only during
// XlaCompiler's constructor.
@@ -263,7 +250,8 @@ class XlaCompiler {
// If set, the XLA representation of variables represented to XLA as the
// shape given by this shape function. Variables are reshaped to this shape
// on write, and reshaped to their original shape on read.
- ShapeRepresentationFn shape_representation_fn;
+ std::function<TensorShape(const TensorShape&, DataType)>
+ variable_representation_shape_fn;
// If not nullptr, populate_resource_manager is called with the
// compilation device's resource manager when the compilation
@@ -312,8 +300,7 @@ class XlaCompiler {
// Returns the shape of the XLA parameter for an argument 'arg'.
// See the class comment for more details about the argument passing
// convention.
- Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation,
- xla::Shape* xla_shape);
+ Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape);
// Retrieves the channel handle associated with `key`. Allocates
// a new channel handle if none exists.
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
index 5670545f9d..4382ffe6ba 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/common_shape_fns.h"
@@ -751,7 +750,10 @@ TEST_F(XlaCompilerTest, Variables) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
-xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
+// Tests a simple graph that reads and writes a variable, with a
+// variable_representation_shape_fn passed to the compiler that flattens all
+// variable tensors to vectors.
+TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
Scope scope = Scope::NewRootScope().ExitOnError();
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
@@ -762,15 +764,7 @@ xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
- TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
- return std::move(graph);
-}
-
-// Tests a simple graph that reads and writes a variable, with a
-// shape_representation_fn passed to the compiler that flattens all
-// variable tensors to vectors.
-TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
+ TF_ASSERT_OK(scope.ToGraph(graph.get()));
// Builds a description of the arguments.
std::vector<XlaCompiler::Argument> args(2);
@@ -785,33 +779,15 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
// Compiles the graph.
XlaCompiler::Options options = DefaultOptions();
- options.shape_representation_fn = [](const TensorShape& shape,
- DataType type) {
+ options.variable_representation_shape_fn = [](const TensorShape& shape,
+ DataType type) {
return TensorShape({shape.num_elements()});
};
XlaCompiler compiler(options);
- XlaCompiler::CompileOptions compile_options;
- compile_options.is_entry_computation = false; // Only reshape variables.
-
XlaCompiler::CompilationResult result;
- TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
- args, &result));
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
- client_->GetComputationShape(*result.computation));
-
- ASSERT_EQ(program_shape->parameters_size(), 2);
- EXPECT_TRUE(
- xla::ShapeUtil::Compatible(program_shape->parameters(0),
- xla::ShapeUtil::MakeShape(xla::S32, {2, 2})));
- EXPECT_TRUE(xla::ShapeUtil::Compatible(
- program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
- EXPECT_TRUE(xla::ShapeUtil::Compatible(
- program_shape->result(),
- xla::ShapeUtil::MakeTupleShape(
- {xla::ShapeUtil::MakeShape(xla::S32, {2, 2}),
- xla::ShapeUtil::MakeShape(xla::S32, {4})})));
+ TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
+ std::move(graph), args, &result));
// Tests that the generated computation works.
std::unique_ptr<xla::Literal> param0_literal =
@@ -839,74 +815,5 @@ TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
}
-TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());
-
- // Builds a description of the arguments.
- std::vector<XlaCompiler::Argument> args(2);
- args[0].kind = XlaCompiler::Argument::kParameter;
- args[0].type = DT_INT32;
- args[0].shape = TensorShape({2, 2});
- args[1].kind = XlaCompiler::Argument::kResource;
- args[1].resource_kind = XlaResource::kVariable;
- args[1].initialized = true;
- args[1].type = DT_INT32;
- args[1].shape = TensorShape({2, 2});
-
- // Compiles the graph.
- XlaCompiler::Options options = DefaultOptions();
- options.shape_representation_fn = [](const TensorShape& shape,
- DataType type) {
- return TensorShape({shape.num_elements()});
- };
- XlaCompiler compiler(options);
-
- XlaCompiler::CompileOptions compile_options;
- compile_options.is_entry_computation = true; // Reshape args and retvals.
-
- XlaCompiler::CompilationResult result;
- TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
- args, &result));
-
- TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
- client_->GetComputationShape(*result.computation));
-
- ASSERT_EQ(program_shape->parameters_size(), 2);
- EXPECT_TRUE(xla::ShapeUtil::Compatible(
- program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4})));
- EXPECT_TRUE(xla::ShapeUtil::Compatible(
- program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
- EXPECT_TRUE(xla::ShapeUtil::Compatible(
- program_shape->result(),
- xla::ShapeUtil::MakeTupleShape(
- {xla::ShapeUtil::MakeShape(xla::S32, {4}),
- xla::ShapeUtil::MakeShape(xla::S32, {4})})));
-
- // Tests that the generated computation works.
- std::unique_ptr<xla::Literal> param0_literal =
- xla::Literal::CreateR1<int32>({4, 55, 1, -3});
- std::unique_ptr<xla::Literal> param1_literal =
- xla::Literal::CreateR1<int32>({22, 11, 33, 404});
- std::unique_ptr<xla::GlobalData> param0_data =
- client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
- std::unique_ptr<xla::GlobalData> param1_data =
- client_->TransferToServer(*param1_literal).ConsumeValueOrDie();
-
- std::unique_ptr<xla::GlobalData> actual =
- client_
- ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
- .ConsumeValueOrDie();
- std::unique_ptr<xla::Literal> actual_literal =
- client_->Transfer(*actual).ConsumeValueOrDie();
-
- std::unique_ptr<xla::Literal> expected0 =
- xla::Literal::CreateR1<int32>({27, 67, 35, 402});
- std::unique_ptr<xla::Literal> expected1 =
- xla::Literal::CreateR1<int32>({26, 66, 34, 401});
- std::unique_ptr<xla::Literal> expected_literal =
- xla::Literal::MakeTuple({expected0.get(), expected1.get()});
- EXPECT_TRUE(xla::LiteralTestUtil::Equal(*expected_literal, *actual_literal));
-}
-
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 098072d33c..3dd2d183f3 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -65,30 +65,26 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
XlaContext::XlaContext(
XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
- bool is_entry_computation,
const std::function<TensorShape(const TensorShape&, DataType)>*
- shape_representation_fn)
+ variable_representation_shape_fn)
: compiler_(compiler),
builder_(builder),
allow_cpu_custom_calls_(allow_cpu_custom_calls),
resolve_compile_time_constants_(resolve_compile_time_constants),
- is_entry_computation_(is_entry_computation),
- shape_representation_fn_(shape_representation_fn) {}
+ variable_representation_shape_fn_(variable_representation_shape_fn) {}
string XlaContext::DebugString() { return "TLA JIT context"; }
// This is called by the Retval Op to associate a computed value
// with a specific return value of the subgraph.
void XlaContext::AddRetval(int retval_index, DataType type,
- const TensorShape& shape, const xla::XlaOp& handle) {
+ const xla::XlaOp& handle) {
VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
// Add the return value to the list being built up.
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
}
- XlaExpression e;
- e.set_handle(handle);
- retvals_[retval_index] = Retval{type, shape, e};
+ retvals_[retval_index].set_handle(handle);
}
Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
@@ -98,11 +94,13 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
if (retvals_.size() <= retval_index) {
retvals_.resize(retval_index + 1);
}
- Tensor value;
- TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
- XlaExpression e;
- e.set_constant_value(value);
- retvals_[retval_index] = Retval{dtype, value.shape(), e};
+ if (resolve_compile_time_constants_) {
+ Tensor value;
+ TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
+ retvals_[retval_index].set_constant_value(std::move(value));
+ } else {
+ retvals_[retval_index].set_handle(builder_->ConstantLiteral(literal));
+ }
return Status::OK();
}
@@ -119,9 +117,9 @@ Status XlaContext::CreateResource(
return Status::OK();
}
-TensorShape XlaContext::RepresentationShape(const TensorShape& shape,
- DataType type) const {
- return (*shape_representation_fn_)(shape, type);
+TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape,
+ DataType type) const {
+ return (*variable_representation_shape_fn_)(shape, type);
}
const xla::XlaComputation* XlaContext::GetOrCreateMax(const DataType type) {
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 3ad2b2e44e..1136ffe507 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -42,13 +42,11 @@ class XlaContext : public ResourceBase {
static XlaContext& Get(const OpKernelContext* ctx);
static XlaContext& Get(const XlaOpKernelContext* ctx);
- // Creates a new XlaContext. See the documentation on the class data fields
- // for descriptions of the arguments.
+ // Creates a new XlaContext.
XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
- bool is_entry_computation,
const std::function<TensorShape(const TensorShape&, DataType)>*
- shape_representation_fn);
+ variable_representation_shape_fn);
// Virtual method defined by ResourceBase.
string DebugString() override;
@@ -60,26 +58,14 @@ class XlaContext : public ResourceBase {
bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
- bool resolve_compile_time_constants() const {
- return resolve_compile_time_constants_;
- }
- bool is_entry_computation() const { return is_entry_computation_; }
-
const std::vector<XlaExpression>& args() const { return args_; }
void set_args(std::vector<XlaExpression> args);
- struct Retval {
- DataType type;
- TensorShape shape;
- // An XlaExpression representing the Retval's value.
- XlaExpression expression;
- };
- const std::vector<Retval>& retvals() { return retvals_; }
+ const std::vector<XlaExpression>& retvals() { return retvals_; }
// This is called by the Retval Op to associate a computed value
// with a specific return value of the subgraph.
- void AddRetval(int retval_index, DataType type, const TensorShape& shape,
- const xla::XlaOp& handle);
+ void AddRetval(int retval_index, DataType type, const xla::XlaOp& handle);
// As for Retval, but for return values that are compile-time constants.
Status AddConstRetval(int retval_index, DataType dtype,
@@ -100,9 +86,9 @@ class XlaContext : public ResourceBase {
}
// Returns the XLA shape to be used to represent a variable of TF `shape`
- // and `type`, or of an argument or return value of a top-level computation.
- TensorShape RepresentationShape(const TensorShape& shape,
- DataType type) const;
+ // and `type`.
+ TensorShape VariableRepresentationShape(const TensorShape& shape,
+ DataType type) const;
// Get an XLA lambda to compute Max. This is cached in the
// XlaContext since it may be used by multiple Ops. There is a
@@ -145,19 +131,15 @@ class XlaContext : public ResourceBase {
std::vector<XlaExpression> args_;
// Return values of the Tensorflow graph, indexed by _Retval index.
- std::vector<Retval> retvals_;
+ std::vector<XlaExpression> retvals_;
// Holds ownership of resources. The resources are not ordered.
std::vector<std::unique_ptr<XlaResource>> resources_;
- // Is this a top-level computation, or an inner computation (e.g., a while
- // body)?
- const bool is_entry_computation_;
-
// A function that describes how variable shapes should be represented
// in XLA. Variable values will be reshaped to this shape. Must be non-null.
const std::function<TensorShape(const TensorShape&, DataType)>*
- shape_representation_fn_;
+ variable_representation_shape_fn_;
// Cache of prebuilt computations indexed by their type.
using ComputationMap = std::map<DataType, xla::XlaComputation>;
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 76c68d81af..2b65f4d5d5 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -314,8 +314,8 @@ Status XlaOpKernelContext::ReadVariableInput(int index, DataType type,
}
XlaContext& xla_context = XlaContext::Get(context_);
- TensorShape representation_shape =
- xla_context.RepresentationShape(variable->shape(), variable->type());
+ TensorShape representation_shape = xla_context.VariableRepresentationShape(
+ variable->shape(), variable->type());
if (representation_shape == variable->shape()) {
*value = variable->value();
} else {
@@ -436,7 +436,7 @@ Status XlaOpKernelContext::AssignVariable(int input_index, DataType type,
XlaContext& xla_context = XlaContext::Get(context_);
TensorShape representation_shape =
- xla_context.RepresentationShape(shape, type);
+ xla_context.VariableRepresentationShape(shape, type);
if (shape != representation_shape) {
handle = builder()->Reshape(handle, representation_shape.dim_sizes());
}