diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/local_service.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/local_service.cc | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 06f43bd3cb..d4d35da9d6 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -68,6 +68,26 @@ LocalService::LocalService(const ServiceOptions& options, std::unique_ptr<Backend> execute_backend) : Service(options, std::move(execute_backend)) {} +namespace { +// Returns the space required to allocate a shape. If +// allocate_space_for_deep_copy the space includes all sub-buffers of +// a tuple. +int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy, + TransferManager* transfer_manager) { + int64 size = 0; + // TODO(b/33492279) remove once no devices represent result tuples as + // contiguous buffers. + if (allocate_space_for_deep_copy) { + ShapeUtil::ForEachSubshape( + shape, [&size, transfer_manager](const Shape& subshape, + const ShapeIndex& /*index*/) { + size += transfer_manager->GetByteSizeRequirement(subshape); + }); + } + return size; +} +} // namespace + StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts, |