diff options
author | 2018-07-13 19:44:08 -0700 | |
---|---|---|
committer | 2018-07-13 19:44:15 -0700 | |
commit | a25dbac5ee56e03f4b475c443aa58cf2f7160f52 (patch) | |
tree | 4c0b7981a9f88ed3085ed9b233082dfa171c16a0 | |
parent | af69280cc6aa2d4acdb3e8522523d3853878bf3f (diff) | |
parent | 9610dfb8b7a6f0d3c57574baf701868987fef3b1 (diff) |
Merge pull request #20759 from ROCmSoftwarePlatform:upstream-staging-xla-gpu-transfer-manager
PiperOrigin-RevId: 204562235
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc | 17 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h | 2 |
2 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 1446401b19..6c23228976 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -41,11 +41,9 @@ namespace gpu { // TODO(b/30467474) Once GPU infeed implementation settles, consider // folding back the cpu and gpu infeed implementations into a generic // one if possible. -GpuTransferManager::GpuTransferManager() - : GenericTransferManager( - se::cuda::kCudaPlatformId, - /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) - .getPointerSize(0 /* default address space */)) {} +GpuTransferManager::GpuTransferManager(se::Platform::Id id, + unsigned pointer_size) + : GenericTransferManager(id, pointer_size) {} Status GpuTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const LiteralSlice& literal) { @@ -179,13 +177,16 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( } // namespace gpu } // namespace xla -static std::unique_ptr<xla::TransferManager> CreateGpuTransferManager() { - return xla::MakeUnique<xla::gpu::GpuTransferManager>(); +static std::unique_ptr<xla::TransferManager> CreateNVPTXTransferManager() { + return xla::MakeUnique<xla::gpu::GpuTransferManager>( + /*id=*/stream_executor::cuda::kCudaPlatformId, + /*pointer_size=*/llvm::DataLayout(xla::gpu::GpuCompiler::kDataLayout) + .getPointerSize(0 /* default address space */)); } static bool InitModule() { xla::TransferManager::RegisterTransferManager( - stream_executor::cuda::kCudaPlatformId, &CreateGpuTransferManager); + stream_executor::cuda::kCudaPlatformId, &CreateNVPTXTransferManager); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index 8122c9d8c3..dceeb9e2eb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -35,7 +35,7 @@ namespace gpu { // handles GPU-specific infeed. class GpuTransferManager : public GenericTransferManager { public: - GpuTransferManager(); + GpuTransferManager(se::Platform::Id id, unsigned pointer_size); ~GpuTransferManager() override {} Status TransferLiteralToInfeed(se::StreamExecutor* executor, |