aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/local_service.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/local_service.cc')
-rw-r--r--tensorflow/compiler/xla/service/local_service.cc20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 06f43bd3cb..d4d35da9d6 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -68,6 +68,26 @@ LocalService::LocalService(const ServiceOptions& options,
std::unique_ptr<Backend> execute_backend)
: Service(options, std::move(execute_backend)) {}
+namespace {
+// Returns the space required to allocate a shape. If
+// allocate_space_for_deep_copy the space includes all sub-buffers of
+// a tuple.
+int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy,
+ TransferManager* transfer_manager) {
+ int64 size = 0;
+ // TODO(b/33492279) remove once no devices represent result tuples as
+ // contiguous buffers.
+ if (allocate_space_for_deep_copy) {
+ ShapeUtil::ForEachSubshape(
+ shape, [&size, transfer_manager](const Shape& subshape,
+ const ShapeIndex& /*index*/) {
+ size += transfer_manager->GetByteSizeRequirement(subshape);
+ });
+ }
+ return size;
+}
+} // namespace
+
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
const ComputationHandle& computation,
const tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,