aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/generic_transfer_manager.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/generic_transfer_manager.cc')
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.cc9
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index a99e2b7794..78dc0ad4fc 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -38,7 +38,14 @@ namespace xla {
GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id,
size_t pointer_size)
- : platform_id_(platform_id), pointer_size_(pointer_size) {}
+ : platform_id_(platform_id), pointer_size_(pointer_size) {
+ // We currently only support kHostPlatformId for CPU, kCudaPlatformId for
+ // GPU and kInterpreterPlatformId for Interpreter. Before supporting other
+ // platforms, we need to test this transfer manager on them.
+ CHECK(platform_id_ == se::host::kHostPlatformId ||
+ platform_id_ == se::interpreter::kInterpreterPlatformId ||
+ platform_id_ == se::cuda::kCudaPlatformId);
+}
se::Platform::Id GenericTransferManager::PlatformId() const {
return platform_id_;