diff options
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc | 17 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h | 2 |
2 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 1446401b19..6c23228976 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -41,11 +41,9 @@ namespace gpu { // TODO(b/30467474) Once GPU infeed implementation settles, consider // folding back the cpu and gpu infeed implementations into a generic // one if possible. -GpuTransferManager::GpuTransferManager() - : GenericTransferManager( - se::cuda::kCudaPlatformId, - /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) - .getPointerSize(0 /* default address space */)) {} +GpuTransferManager::GpuTransferManager(se::Platform::Id id, + unsigned pointer_size) + : GenericTransferManager(id, pointer_size) {} Status GpuTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const LiteralSlice& literal) { @@ -179,13 +177,16 @@ Status GpuTransferManager::TransferLiteralFromOutfeed( } // namespace gpu } // namespace xla -static std::unique_ptr<xla::TransferManager> CreateGpuTransferManager() { - return xla::MakeUnique<xla::gpu::GpuTransferManager>(); +static std::unique_ptr<xla::TransferManager> CreateNVPTXTransferManager() { + return xla::MakeUnique<xla::gpu::GpuTransferManager>( + /*id=*/stream_executor::cuda::kCudaPlatformId, + /*pointer_size=*/llvm::DataLayout(xla::gpu::GpuCompiler::kDataLayout) + .getPointerSize(0 /* default address space */)); } static bool InitModule() { xla::TransferManager::RegisterTransferManager( - stream_executor::cuda::kCudaPlatformId, &CreateGpuTransferManager); + stream_executor::cuda::kCudaPlatformId, &CreateNVPTXTransferManager); return true; } static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h index 8122c9d8c3..dceeb9e2eb 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h @@ -35,7 +35,7 @@ namespace gpu { // handles GPU-specific infeed. class GpuTransferManager : public GenericTransferManager { public: - GpuTransferManager(); + GpuTransferManager(se::Platform::Id id, unsigned pointer_size); ~GpuTransferManager() override {} Status TransferLiteralToInfeed(se::StreamExecutor* executor, |