Diffstat (limited to 'tensorflow/compiler/jit/xla_device.cc')
-rw-r--r-- tensorflow/compiler/jit/xla_device.cc | 36
1 file changed, 30 insertions(+), 6 deletions(-)
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index f13b46c532..ed007d603e 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device_context.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/core/common_runtime/device.h"
@@ -105,6 +106,25 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
return alloc_ptr;
}
+namespace {
+
+// Default PaddedShapeFn implementation that simply returns the unpadded
+// on-device shape. This is accurate for CPU and GPU devices that neither
+// transpose nor pad tensors.
+Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
+ const tensorflow::XlaTensor* xla_tensor =
+ tensorflow::XlaTensor::FromTensor(&tensor);
+ if (xla_tensor == nullptr) {
+ return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
+ }
+
+ const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
+ *shape = shaped_buffer.on_device_shape();
+ return Status::OK();
+}
+
+} // namespace
+
/* static */ Status XlaDevice::Create(
const string& platform_name, const string& device_name, int device_ordinal,
const string& jit_device_name, const SessionOptions& options,
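The hunk above introduces DefaultPaddedShapeFn, which reports the on-device shape unchanged; a device that does pad tensors would supply its own PaddedShapeFn. As a hedged illustration only, a device that rounds every dimension up to a multiple of 8 might plug in something like the following sketch (PadTo8PaddedShapeFn and kPadMultiple are invented names; TensorShapeToXLAShape and xla::Shape are the same APIs the default uses):

// Illustrative sketch, not part of this change: a hypothetical PaddedShapeFn
// for a device that pads each dimension up to a multiple of 8.
Status PadTo8PaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
  // Start from the unpadded logical shape, as DefaultPaddedShapeFn does.
  TF_RETURN_IF_ERROR(
      TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape));
  constexpr int64 kPadMultiple = 8;  // hypothetical device padding granule
  for (int i = 0; i < shape->dimensions_size(); ++i) {
    const int64 dim = shape->dimensions(i);
    // Round dim up to the next multiple of kPadMultiple.
    shape->set_dimensions(
        i, (dim + kPadMultiple - 1) / kPadMultiple * kPadMultiple);
  }
  return Status::OK();
}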
@@ -112,7 +132,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
const XlaOpRegistry::DeviceRegistration& registration,
bool transfer_as_literal,
const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
- std::unique_ptr<XlaDevice>* device) {
+ const PaddedShapeFn& padded_shape_fn, std::unique_ptr<XlaDevice>* device) {
VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":"
<< device_ordinal;
@@ -133,17 +153,20 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
device->reset(new XlaDevice(
options, attrs, device_ordinal, DeviceType(jit_device_name),
- platform.ValueOrDie(), transfer_as_literal, shape_representation_fn));
+ platform.ValueOrDie(), transfer_as_literal, shape_representation_fn,
+ padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn));
return Status::OK();
}
XlaDevice::Metadata::Metadata(
int device_ordinal, se::Platform* platform, const DeviceType& device_type,
- XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+ XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+ PaddedShapeFn padded_shape_fn)
: device_ordinal_(device_ordinal),
device_type_(device_type),
platform_(platform),
- shape_representation_fn_(std::move(shape_representation_fn)) {}
+ shape_representation_fn_(std::move(shape_representation_fn)),
+ padded_shape_fn_(std::move(padded_shape_fn)) {}
int XlaDevice::Metadata::device_ordinal() const { return device_ordinal_; }
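The fallback in Create() above, padded_shape_fn ? padded_shape_fn : DefaultPaddedShapeFn, relies on std::function's operator bool: a default-constructed PaddedShapeFn is falsy, so callers that don't care about padding can pass an empty function and transparently get the default. A minimal standalone sketch of that selection behavior (names here are illustrative):

// Standalone illustration of the fallback: an empty std::function converts
// to false in a boolean context, so the default callback is selected.
#include <functional>
#include <iostream>

using Fn = std::function<const char*()>;

const char* DefaultFn() { return "default"; }

int main() {
  Fn unset;                             // empty; operator bool() == false
  Fn custom = [] { return "custom"; };  // non-empty; used as-is
  Fn pick_a = unset ? unset : Fn(DefaultFn);
  Fn pick_b = custom ? custom : Fn(DefaultFn);
  std::cout << pick_a() << " " << pick_b() << "\n";  // prints: default custom
  return 0;
}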
@@ -178,10 +201,11 @@ XlaDevice::XlaDevice(
const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
se::Platform* platform, bool transfer_as_literal,
- const XlaCompiler::ShapeRepresentationFn& shape_representation_fn)
+ const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
+ const PaddedShapeFn& padded_shape_fn)
: LocalDevice(options, attrs),
xla_metadata_(device_ordinal, platform, jit_device_name,
- shape_representation_fn),
+ shape_representation_fn, padded_shape_fn),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
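With padded_shape_fn threaded into xla_metadata_ above, code that holds an XlaDevice::Metadata can query the padded on-device shape of a tensor. A hedged sketch of such a consumer, assuming an accessor named padded_shape_fn() (this diff shows only the stored member padded_shape_fn_ being initialized):

// Sketch only: query the padded on-device shape via the stored callback.
// The padded_shape_fn() accessor name is an assumption, not shown in this diff.
Status GetPaddedShape(const XlaDevice::Metadata& metadata, const Tensor& tensor,
                      xla::Shape* padded_shape) {
  return metadata.padded_shape_fn()(tensor, padded_shape);
}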