Diffstat (limited to 'tensorflow/compiler/jit/xla_device.cc')
-rw-r--r-- | tensorflow/compiler/jit/xla_device.cc | 36
1 file changed, 30 insertions(+), 6 deletions(-)
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index f13b46c532..ed007d603e 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/jit/xla_device_context.h"
 #include "tensorflow/compiler/jit/xla_device_ops.h"
 #include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/core/common_runtime/device.h"
@@ -105,6 +106,25 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
   return alloc_ptr;
 }
 
+namespace {
+
+// Default PaddedShapeFn implementation that simply returns the unpadded
+// on-device shape. This is accurate for CPU and GPU devices that neither
+// transpose nor pad tensors.
+Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
+  const tensorflow::XlaTensor* xla_tensor =
+      tensorflow::XlaTensor::FromTensor(&tensor);
+  if (xla_tensor == nullptr) {
+    return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
+  }
+
+  const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
+  *shape = shaped_buffer.on_device_shape();
+  return Status::OK();
+}
+
+}  // namespace
+
 /* static */ Status XlaDevice::Create(
     const string& platform_name, const string& device_name, int device_ordinal,
     const string& jit_device_name, const SessionOptions& options,
@@ -112,7 +132,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
     const XlaOpRegistry::DeviceRegistration& registration,
     bool transfer_as_literal,
     const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
-    std::unique_ptr<XlaDevice>* device) {
+    const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device) {
   VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
           << device_ordinal;
 
@@ -133,17 +153,20 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
   device->reset(new XlaDevice(
       options, attrs, device_ordinal, DeviceType(jit_device_name),
-      platform.ValueOrDie(), transfer_as_literal, shape_representation_fn));
+      platform.ValueOrDie(), transfer_as_literal, shape_representation_fn,
+      padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
   return Status::OK();
 }
 
 XlaDevice::Metadata::Metadata(
     int device_ordinal, se::Platform* platform, const DeviceType& device_type,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    PaddedShapeFn padded_shape_fn)
     : device_ordinal_(device_ordinal),
       device_type_(device_type),
       platform_(platform),
-      shape_representation_fn_(std::move(shape_representation_fn)) {}
+      shape_representation_fn_(std::move(shape_representation_fn)),
+      padded_shape_fn_(std::move(padded_shape_fn)) {}
 
 int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
 
@@ -178,10 +201,11 @@ XlaDevice::XlaDevice(
     const SessionOptions& options, const DeviceAttributes& attrs,
     int device_ordinal, const DeviceType& jit_device_name,
     se::Platform* platform, bool transfer_as_literal,
-    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn)
+    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
+    const PaddedShapeFn& padded_shape_fn)
     : LocalDevice(options, attrs),
       xla_metadata_(device_ordinal, platform, jit_device_name,
-                    shape_representation_fn),
+                    shape_representation_fn, padded_shape_fn),
       device_ordinal_(device_ordinal),
       jit_device_name_(jit_device_name),
       xla_allocator_(nullptr),
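The diff above threads a PaddedShapeFn through XlaDevice: callers of XlaDevice::Create may supply one, and an empty function falls back to DefaultPaddedShapeFn. From that fallback, the callable's signature is Status(const Tensor&, xla::Shape*). As a minimal sketch of why the hook exists, a backend whose on-device layouts pad every dimension up to a multiple of 8 might report padded shapes as below. The function name and the multiple-of-8 rule are hypothetical illustrations, not part of this commit.

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Hypothetical PaddedShapeFn for a device that pads every dimension of the
// on-device layout up to the next multiple of 8.
Status PadToMultipleOfEightShapeFn(const Tensor& tensor, xla::Shape* shape) {
  // Start from the unpadded logical shape of the tensor.
  TF_RETURN_IF_ERROR(
      TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape));
  // Round each dimension up to a multiple of 8 to match the assumed layout.
  for (int i = 0; i < shape->dimensions_size(); ++i) {
    const int64 padded = (shape->dimensions(i) + 7) / 8 * 8;
    shape->set_dimensions(i, padded);
  }
  return Status::OK();
}

}  // namespace tensorflow

A backend would pass such a function as the new padded_shape_fn argument to XlaDevice::Create; per the ternary in the diff, passing an empty PaddedShapeFn selects DefaultPaddedShapeFn, which is correct for CPU and GPU devices that neither transpose nor pad.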