aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc17
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h2
2 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index 1446401b19..6c23228976 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -41,11 +41,9 @@ namespace gpu {
// TODO(b/30467474) Once GPU infeed implementation settles, consider
// folding back the cpu and gpu infeed implementations into a generic
// one if possible.
-GpuTransferManager::GpuTransferManager()
- : GenericTransferManager(
- se::cuda::kCudaPlatformId,
- /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout)
- .getPointerSize(0 /* default address space */)) {}
+GpuTransferManager::GpuTransferManager(se::Platform::Id id,
+ unsigned pointer_size)
+ : GenericTransferManager(id, pointer_size) {}
Status GpuTransferManager::TransferLiteralToInfeed(
se::StreamExecutor* executor, const LiteralSlice& literal) {
@@ -179,13 +177,16 @@ Status GpuTransferManager::TransferLiteralFromOutfeed(
} // namespace gpu
} // namespace xla
-static std::unique_ptr<xla::TransferManager> CreateGpuTransferManager() {
- return xla::MakeUnique<xla::gpu::GpuTransferManager>();
+static std::unique_ptr<xla::TransferManager> CreateNVPTXTransferManager() {
+ return xla::MakeUnique<xla::gpu::GpuTransferManager>(
+ /*id=*/stream_executor::cuda::kCudaPlatformId,
+ /*pointer_size=*/llvm::DataLayout(xla::gpu::GpuCompiler::kDataLayout)
+ .getPointerSize(0 /* default address space */));
}
static bool InitModule() {
xla::TransferManager::RegisterTransferManager(
- stream_executor::cuda::kCudaPlatformId, &CreateGpuTransferManager);
+ stream_executor::cuda::kCudaPlatformId, &CreateNVPTXTransferManager);
return true;
}
static bool module_initialized = InitModule();
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
index 8122c9d8c3..dceeb9e2eb 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h
@@ -35,7 +35,7 @@ namespace gpu {
// handles GPU-specific infeed.
class GpuTransferManager : public GenericTransferManager {
public:
- GpuTransferManager();
+ GpuTransferManager(se::Platform::Id id, unsigned pointer_size);
~GpuTransferManager() override {}
Status TransferLiteralToInfeed(se::StreamExecutor* executor,