diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/generic_transfer_manager.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/generic_transfer_manager.cc | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index a99e2b7794..78dc0ad4fc 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -38,7 +38,14 @@ namespace xla { GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id, size_t pointer_size) - : platform_id_(platform_id), pointer_size_(pointer_size) {} + : platform_id_(platform_id), pointer_size_(pointer_size) { + // We currently only support kHostPlatformId for CPU, kCudaPlatformId for + // GPU and kInterpreterPlatformId for Interpreter. Before supporting other + // platforms, we need to test this transfer manager on them. + CHECK(platform_id_ == se::host::kHostPlatformId || + platform_id_ == se::interpreter::kInterpreterPlatformId || + platform_id_ == se::cuda::kCudaPlatformId); +} se::Platform::Id GenericTransferManager::PlatformId() const { return platform_id_; |