aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-13 19:44:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-13 19:44:15 -0700
commita25dbac5ee56e03f4b475c443aa58cf2f7160f52 (patch)
tree4c0b7981a9f88ed3085ed9b233082dfa171c16a0
parentaf69280cc6aa2d4acdb3e8522523d3853878bf3f (diff)
parent9610dfb8b7a6f0d3c57574baf701868987fef3b1 (diff)
Merge pull request #20759 from ROCmSoftwarePlatform:upstream-staging-xla-gpu-transfer-manager
PiperOrigin-RevId: 204562235
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc17
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h2
2 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index 1446401b19..6c23228976 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -41,11 +41,9 @@ namespace gpu {
// TODO(b/30467474) Once GPU infeed implementation settles, consider
// folding back the cpu and gpu infeed implementations into a generic
// one if possible.
-GpuTransferManager::GpuTransferManager()
- : GenericTransferManager(
- se::cuda::kCudaPlatformId,
- /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout)
- .getPointerSize(0 /* default address space */)) {}
+GpuTransferManager::GpuTransferManager(se::Platform::Id id,
+ unsigned pointer_size)
+ : GenericTransferManager(id, pointer_size) {}
Status GpuTransferManager::TransferLiteralToInfeed(
se::StreamExecutor* executor, const LiteralSlice& literal) {
@@ -179,13 +177,16 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
} // namespace gpu
} // namespace xla
-static std::unique_ptr<xla::TransferManager> CreateGpuTransferManager() {
- return xla::MakeUnique<xla::gpu::GpuTransferManager>();
+static std::unique_ptr<xla::TransferManager> CreateNVPTXTransferManager() {
+ return xla::MakeUnique<xla::gpu::GpuTransferManager>(
+ /*id=*/stream_executor::cuda::kCudaPlatformId,
+ /*pointer_size=*/llvm::DataLayout(xla::gpu::GpuCompiler::kDataLayout)
+ .getPointerSize(0 /* default address space */));
}
static bool InitModule() {
xla::TransferManager::RegisterTransferManager(
- stream_executor::cuda::kCudaPlatformId, &CreateGpuTransferManager);
+ stream_executor::cuda::kCudaPlatformId, &CreateNVPTXTransferManager);
return true;
}
static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
index 8122c9d8c3..dceeb9e2eb 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
@@ -35,7 +35,7 @@ namespace gpu {
// handles GPU-specific infeed.
class GpuTransferManager : public GenericTransferManager {
public:
- GpuTransferManager();
+ GpuTransferManager(se::Platform::Id id, unsigned pointer_size);
~GpuTransferManager() override {}
Status TransferLiteralToInfeed(se::StreamExecutor* executor,