about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
author: Peter Hawkins <phawkins@google.com> 2017-03-07 08:52:37 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-03-07 09:22:34 -0800
commit: 738143e6cd7f9eba0a0e77b44c6cc5ae4e1781ad (patch)
tree: 452a8a2f41e5061140a7893fbee6764674bc62d3 /tensorflow/compiler/xla
parent: ceb7fc1b64611b09a1d03490f5f0a9c155a93137 (diff)
[TF:XLA] Remove support for client-allocated result buffers.
This code path is unused; TensorFlow ended up settling on having XLA allocate result buffers using TensorFlow's allocator. Remove it to reduce the proliferation of ExecuteXYZ() methods. (Change: 149423775)
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r-- tensorflow/compiler/xla/client/local_client.cc 45
-rw-r--r-- tensorflow/compiler/xla/client/local_client.h 18
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_executable.cc 102
-rw-r--r-- tensorflow/compiler/xla/service/cpu/cpu_executable.h 6
-rw-r--r-- tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc 8
-rw-r--r-- tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h 6
-rw-r--r-- tensorflow/compiler/xla/service/executable.h 8
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_executable.cc 106
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_executable.h 6
-rw-r--r-- tensorflow/compiler/xla/service/local_service.cc 102
-rw-r--r-- tensorflow/compiler/xla/service/local_service.h 32
-rw-r--r-- tensorflow/compiler/xla/tests/local_client_test_base.cc 15
-rw-r--r-- tensorflow/compiler/xla/tests/local_client_test_base.h 10
13 files changed, 15 insertions, 449 deletions
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index c51dc05af8..d515946781 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -184,43 +184,6 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalExecutable::Run(
/*hlo_execution_profile=*/nullptr);
}
-tensorflow::Status LocalExecutable::Run(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const ExecutableRunOptions& options, ShapedBuffer* result) {
- const ComputationLayout& computation_layout =
- executable_->module_config().entry_computation_layout();
- TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options));
-
- if (!computation_layout.result_layout().MatchesLayoutInShape(
- result->shape())) {
- return InvalidArgument(
- "result buffer does not match shape or layout of computation result: "
- "expected %s, got %s",
- ShapeUtil::HumanString(computation_layout.result_layout().shape())
- .c_str(),
- ShapeUtil::HumanString(result->shape()).c_str());
- }
-
- ExecutableRunOptions actual_options = options;
- Backend::StreamPtr stream;
- if (options.stream() == nullptr) {
- TF_ASSIGN_OR_RETURN(
- stream, BorrowStreamForDevice(options.device_ordinal(), backend_));
- actual_options.set_stream(stream.get());
- }
- if (options.allocator() == nullptr) {
- actual_options.set_allocator(backend_->memory_allocator());
- }
- ServiceExecutableRunOptions service_options(actual_options,
- backend_->StreamBorrower());
-
- if (executable_->dumping()) {
- return Unimplemented("dumping execution not supported on this path");
- }
- return executable_->ExecuteOnStream(&service_options, arguments, result,
- /*hlo_execution_profile=*/nullptr);
-}
-
StatusOr<std::unique_ptr<ShapedBuffer>> LocalExecutable::ExecuteAndDump(
const ServiceExecutableRunOptions* run_options,
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
@@ -287,14 +250,6 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalClient::ExecuteLocally(
options);
}
-tensorflow::Status LocalClient::ExecuteLocally(
- const Computation& computation,
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const LocalExecuteOptions& options, ShapedBuffer* result) {
- return local_service_->ExecuteLocally(computation.handle(), arguments,
- options, result);
-}
-
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
LocalClient::CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h
index 484a6bc47c..b216ea6097 100644
--- a/tensorflow/compiler/xla/client/local_client.h
+++ b/tensorflow/compiler/xla/client/local_client.h
@@ -83,12 +83,6 @@ class LocalExecutable {
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableRunOptions& options);
- // Overload which places the computation result in the given preallocated
- // buffer.
- tensorflow::Status Run(
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const ExecutableRunOptions& options, ShapedBuffer* result);
-
// Return the layout (contained in a shape) of the result produced by the
// computation.
const Shape& result_layout() const {
@@ -200,18 +194,6 @@ class LocalClient : public Client {
const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const LocalExecuteOptions& options);
- // Overload of ExecuteLocally which writes the result into the given
- // ShapedBuffer "result". Result is const because the ShapedBuffer data
- // structure itself is not modified, only the buffers in device memory to
- // which it refers.
- //
- // TODO(b/31220873): Remove ExecuteLocally methods. The path forward is to use
- // Compile and run the returned LocalExecutable.
- tensorflow::Status ExecuteLocally(
- const Computation& computation,
- const tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const LocalExecuteOptions& options, ShapedBuffer* result);
-
// Build and return a LocalExecutable object. The executable is compiled using
// the given argument layouts and options.
StatusOr<std::unique_ptr<LocalExecutable>> Compile(
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 9392241dbe..88283e6010 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -360,108 +360,6 @@ StatusOr<std::unique_ptr<ShapedBuffer>> CpuExecutable::ExecuteOnStream(
return std::move(result_buffer);
}
-Status CpuExecutable::ExecuteOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- ShapedBuffer* result_buffer, HloExecutionProfile* hlo_execution_profile) {
- // Every array element in the result of the computation must be unambiguously
- // produced by a single instruction.
- // This ensures that the buffers inside result_buffer can be assigned without
- // conflict to the respective instructions because there is a one-to-one
- // correspondence between hlo instructions and array buffers in the result.
- if (GetRootPointsToSet().IsAmbiguous()) {
- return Unimplemented(
- "Points-to set of root instruction is ambiguous or not distinct");
- }
- if (!ShapeUtil::Compatible(result_buffer->shape(), result_shape())) {
- return InvalidArgument(
- "Result buffer shape %s is incompatible with result shape %s",
- ShapeUtil::HumanString(result_buffer->shape()).c_str(),
- ShapeUtil::HumanString(result_shape()).c_str());
- }
-
- se::Stream* stream = run_options->stream();
- DeviceMemoryAllocator* memory_allocator = run_options->allocator();
- std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
-
- // If two tuple elements point to the same buffer, one of the results in the
- // result buffer is considered the canonical location while the other result
- // points to it (instead of, say, making a copy of the result).
- // buffer_index_to_shape_index maps a buffer index to its canonical location
- // in the result buffer.
- std::unordered_map<BufferAllocation::Index, size_t>
- buffer_index_to_shape_index;
-
- // Copy values from result_buffer to the index in "buffers". These buffers
- // will not be allocated in the call to AllocateBuffers.
- std::vector<bool> buffers_in_result(assignment_->Allocations().size(), false);
- TF_RETURN_IF_ERROR(
- result_buffer->mutable_shape_index_to_buffer_entry()
- ->ForEachMutableElement(
- [&buffers, &buffers_in_result, &buffer_index_to_shape_index,
- result_buffer, this](const ShapeIndex& index, bool is_leaf,
- size_t* buffer_entry) {
- if (is_leaf) {
- const std::vector<const LogicalBuffer*>& sources =
- this->GetRootPointsToSet().element(index);
- // The points to set is unambiguous so the set should be a
- // singleton.
- CHECK_EQ(1, sources.size());
- const LogicalBuffer* buffer_source = sources[0];
- HloInstruction* src = buffer_source->instruction();
-
- // The source for this result buffer can be a nested buffer
- // such as a tuple element.
-
- // The source instruction should have a non-parameter buffer
- // assigned.
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
- this->assignment_->GetUniqueSlice(
- src, buffer_source->index()));
- CHECK(!slice.allocation()->is_entry_computation_parameter());
-
- const BufferAllocation::Index buffer_index = slice.index();
- auto insert_result = buffer_index_to_shape_index.emplace(
- buffer_index, *buffer_entry);
- if (insert_result.second) {
- // The points-to set is distinct so this buffer should not
- // have
- // been assigned in a previous invocation of this lambda.
- perftools::gputools::DeviceMemoryBase memory_base =
- result_buffer->buffer(index);
- CHECK(!memory_base.is_null());
- CHECK(buffers[buffer_index].is_null());
- buffers[buffer_index] = memory_base;
- buffers_in_result[buffer_index] = true;
- } else {
- // Record the fact that this tuple element is identical to
- // some
- // prior result.
- *buffer_entry = insert_result.first->second;
- }
- }
- return Status::OK();
- }));
-
- TF_RETURN_IF_ERROR(AllocateBuffers(
- memory_allocator, stream->parent()->device_ordinal(), &buffers));
- TF_RETURN_IF_ERROR(ExecuteComputeFunction(
- &run_options->run_options(), arguments, buffers, hlo_execution_profile));
-
- // Free all buffers not in the result.
- for (size_t i = 0; i < buffers.size(); ++i) {
- se::DeviceMemoryBase alloc = buffers[i];
- if (!buffers_in_result[i] && !alloc.is_null()) {
- VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
- << alloc.opaque() << "]";
- TF_RETURN_IF_ERROR(memory_allocator->Deallocate(
- stream->parent()->device_ordinal(), &alloc));
- }
- }
-
- return Status::OK();
-}
-
StatusOr<perftools::gputools::DeviceMemoryBase>
CpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 0963f2d470..b04b4e8dd1 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -68,12 +68,6 @@ class CpuExecutable : public Executable {
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) override;
- Status ExecuteOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- ShapedBuffer* result_buffer,
- HloExecutionProfile* hlo_execution_profile) override;
-
StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
index 562a09cad2..bab3440e2c 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
@@ -349,14 +349,6 @@ StatusOr<std::unique_ptr<ShapedBuffer>> ParallelCpuExecutable::ExecuteOnStream(
"ParallelCpuExecutable not supported yet with LocalService execution");
}
-Status ParallelCpuExecutable::ExecuteOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- ShapedBuffer* result_buffer, HloExecutionProfile* hlo_execution_profile) {
- return Unimplemented(
- "preallocated result buffer not supported with ParallelCpuExecutable");
-}
-
StatusOr<perftools::gputools::DeviceMemoryBase>
ParallelCpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
index d274ab1667..7ce059bb1d 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
@@ -71,12 +71,6 @@ class ParallelCpuExecutable : public Executable {
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) override;
- Status ExecuteOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- ShapedBuffer* result_buffer,
- HloExecutionProfile* hlo_execution_profile) override;
-
StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index 6ca9c82914..a697d8cde2 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -71,14 +71,6 @@ class Executable {
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) = 0;
- // Overload of which writes the result into a pre-allocated buffer
- // (result_buffer).
- virtual Status ExecuteOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- ShapedBuffer* result_buffer,
- HloExecutionProfile* hlo_execution_profile) = 0;
-
// Same as ExecuteOnStream(), but this call is non-blocking and returns as
// soon as all of the operations are enqueued for launch on the stream.
virtual StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream(
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 72a1db90a4..b16461b113 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -337,112 +337,6 @@ StatusOr<std::unique_ptr<ShapedBuffer>> GpuExecutable::ExecuteOnStream(
return std::move(shaped_buffer);
}
-Status GpuExecutable::ExecuteOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- ShapedBuffer* result_buffer, HloExecutionProfile* hlo_execution_profile) {
- se::Stream* stream = run_options->stream();
- DeviceMemoryAllocator* memory_allocator = run_options->allocator();
- // This ExecuteOnStream overload should only be called by the LocalService
- // which sets has_hybrid_result to true.
- TF_RET_CHECK(module_config().has_hybrid_result());
-
- // Every array element in the result of the computation must be unambiguously
- // produced by a single instruction.
- // This ensures that the buffers inside result_buffer can be assigned without
- // conflict to the respective instructions because there is a one-to-one
- // correspondence between hlo instructions and array buffers in the result.
- if (GetRootPointsToSet().IsAmbiguous()) {
- return Unimplemented(
- "Points-to set of root instruction is ambiguous or not distinct");
- }
-
- DCHECK(ShapeUtil::Compatible(result_buffer->shape(), result_shape()));
-
- BufferAllocations::Builder buffer_allocations_builder;
- for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
- ++i) {
- const BufferAllocation& allocation = assignment_->GetAllocation(i);
- if (allocation.is_entry_computation_parameter()) {
- auto param_no = allocation.parameter_number();
- if (ShapeUtil::IsTuple(arguments[param_no]->shape())) {
- return Unimplemented("Tuple ShapedBuffer arguments not supported");
- }
- buffer_allocations_builder.RegisterBuffer(
- i, arguments[param_no]->buffer(/*index=*/{}));
- }
- }
-
- // If two tuple elements point to the same buffer, one of the results in the
- // result buffer is considered the canonical location while the other result
- // points to it (instead of, say, making a copy of the result).
- // buffer_index_to_shape_index maps a buffer index to its canonical location
- // in the result buffer.
- std::unordered_map<BufferAllocation::Index, size_t>
- buffer_index_to_shape_index;
-
- // Register DeviceMemoryBase values in result_buffer to their corresponding
- // buffer indices. These buffers will not be allocated in the call to
- // BufferAllocationsBuilder::Build.
- std::set<se::DeviceMemoryBase> buffers_in_result;
- TF_RETURN_IF_ERROR(
- result_buffer->mutable_shape_index_to_buffer_entry()
- ->ForEachMutableElement(
- [&buffer_allocations_builder, &buffers_in_result,
- &buffer_index_to_shape_index, result_buffer, this](
- const ShapeIndex& index, bool is_leaf, size_t* buffer_entry) {
- if (is_leaf) {
- const std::vector<const LogicalBuffer*>& sources =
- this->GetRootPointsToSet().element(index);
- // The points to set is unambiguous so the set should be a
- // singleton. That is, we know exactly which instruction
- // produced the array at this element.
- CHECK_EQ(1, sources.size());
- auto src_hlo = sources[0]->instruction();
-
- VLOG(4) << "Looking at: " << sources[0];
-
- // The source instruction should have a non-parameter buffer
- // assigned.
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
- this->assignment_->GetUniqueSlice(
- src_hlo, sources[0]->index()));
- CHECK(!slice.allocation()->is_entry_computation_parameter());
-
- auto insert_result = buffer_index_to_shape_index.emplace(
- slice.index(), *buffer_entry);
- if (insert_result.second) {
- // The points-to set is distinct so this buffer should not
- // have been assigned in a previous invocation of this
- // lambda.
- perftools::gputools::DeviceMemoryBase memory_base =
- result_buffer->buffer(index);
- CHECK(!memory_base.is_null());
- buffer_allocations_builder.RegisterBuffer(slice.index(),
- memory_base);
- buffers_in_result.insert(memory_base);
- } else {
- // Record the fact that this tuple element is identical to
- // some
- // prior result.
- *buffer_entry = insert_result.first->second;
- }
- }
- return Status::OK();
- }));
-
- se::StreamExecutor* executor = stream->parent();
- auto device_ordinal = executor->device_ordinal();
- TF_ASSIGN_OR_RETURN(auto buffer_allocations,
- buffer_allocations_builder.Build(
- *assignment_, device_ordinal, memory_allocator));
-
- TF_RETURN_IF_ERROR(
- ExecuteThunks(run_options, *buffer_allocations, hlo_execution_profile));
-
- return buffer_allocations->TearDown(buffers_in_result, *assignment_);
-}
-
StatusOr<se::DeviceMemoryBase> GpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 5254571d06..d65394ce30 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -76,12 +76,6 @@ class GpuExecutable : public Executable {
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) override;
- Status ExecuteOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- ShapedBuffer* result_buffer,
- HloExecutionProfile* hlo_execution_profile) override;
-
StatusOr<perftools::gputools::DeviceMemoryBase> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index d736545345..49be075d24 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -186,27 +186,6 @@ StatusOr<GlobalDataHandle> LocalService::AllocateBufferOnDevice(
allocation_size));
}
-StatusOr<std::unique_ptr<ShapedBuffer>> LocalService::ExecuteLocally(
- const ComputationHandle& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const LocalExecuteOptions& options) {
- return ExecuteLocallyInternal(computation, arguments, options,
- /*preallocated_result_buffer=*/nullptr);
-}
-
-tensorflow::Status LocalService::ExecuteLocally(
- const ComputationHandle& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const LocalExecuteOptions& options, ShapedBuffer* result_buffer) {
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<ShapedBuffer> null_buffer,
- ExecuteLocallyInternal(computation, arguments, options, result_buffer));
- // Because the result is written into result_buffer, a null ShapedBuffer
- // pointer should have been returned.
- CHECK_EQ(nullptr, null_buffer.get());
- return tensorflow::Status::OK();
-}
-
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
LocalService::CompileAheadOfTime(
const tensorflow::gtl::ArraySlice<AheadOfTimeComputationInstance>
@@ -256,8 +235,7 @@ LocalService::CompileAheadOfTime(
tensorflow::Status LocalService::ValidateExecuteOptions(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> argument_layouts,
- const LocalExecuteOptions& options,
- const ShapedBuffer* preallocated_result_buffer) {
+ const LocalExecuteOptions& options) {
if (argument_layouts.size() != program_shape.parameters_size()) {
return InvalidArgument(
"invalid number of arguments for computation: expected %d, got %zu",
@@ -314,46 +292,13 @@ tensorflow::Status LocalService::ValidateExecuteOptions(
execute_backend_->platform()->Name().c_str());
}
- if (preallocated_result_buffer != nullptr) {
- if (options.result_layout()) {
- return InvalidArgument(
- "cannot set both result ShapedBuffer and result layout; the result "
- "ShapedBuffer determines the result layout");
- }
- if (!ShapeUtil::Compatible(preallocated_result_buffer->shape(),
- program_shape.result())) {
- return InvalidArgument(
- "result ShapedBuffer of shape %s not compatible with computation "
- "result shape %s",
- ShapeUtil::HumanString(preallocated_result_buffer->shape()).c_str(),
- ShapeUtil::HumanString(program_shape.result()).c_str());
- }
- }
- if (options.result_layout()) {
- TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout(*options.result_layout(),
- program_shape.result()));
- }
-
- // Check that all argument layouts are valid and the right shape.
- for (int i = 0; i < argument_layouts.size(); ++i) {
- const Shape& argument_shape = *argument_layouts[i];
- TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape));
- if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
- return InvalidArgument(
- "invalid argument shape for argument %d, expected %s, got %s", i,
- ShapeUtil::HumanString(program_shape.parameters(i)).c_str(),
- ShapeUtil::HumanString(argument_shape).c_str());
- }
- }
-
return tensorflow::Status::OK();
}
-StatusOr<std::unique_ptr<ShapedBuffer>> LocalService::ExecuteLocallyInternal(
+StatusOr<std::unique_ptr<ShapedBuffer>> LocalService::ExecuteLocally(
const ComputationHandle& computation,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const LocalExecuteOptions& options,
- ShapedBuffer* preallocated_result_buffer) {
+ const LocalExecuteOptions& options) {
TF_ASSIGN_OR_RETURN(UserComputation * user_computation,
computation_tracker_.Resolve(computation));
VersionedComputationHandle versioned_handle =
@@ -388,8 +333,8 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalService::ExecuteLocallyInternal(
argument_layouts[i] = &argument->shape();
}
- TF_RETURN_IF_ERROR(ValidateExecuteOptions(
- *program_shape, argument_layouts, options, preallocated_result_buffer));
+ TF_RETURN_IF_ERROR(
+ ValidateExecuteOptions(*program_shape, argument_layouts, options));
// Construct computation layout from the argument layouts.
auto module_config = MakeUnique<HloModuleConfig>(*program_shape);
@@ -411,10 +356,6 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalService::ExecuteLocallyInternal(
TF_RETURN_IF_ERROR(
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
*options.result_layout()));
- } else if (preallocated_result_buffer != nullptr) {
- TF_RETURN_IF_ERROR(
- computation_layout->mutable_result_layout()->CopyLayoutFromShape(
- preallocated_result_buffer->shape()));
} else {
computation_layout->mutable_result_layout()->SetToDefaultLayout();
}
@@ -454,30 +395,15 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalService::ExecuteLocallyInternal(
argument_buffers, execute_backend_.get(),
service_run_options.stream()->parent(), profile));
- if (preallocated_result_buffer == nullptr) {
- return Service::ExecuteOnStreamWrapper<
- StatusOr<std::unique_ptr<ShapedBuffer>>>(
- executable.get(), &service_run_options, profile,
- [&arguments](Executable* executable,
- const ServiceExecutableRunOptions* run_options,
- HloExecutionProfile* hlo_execution_profile) {
- return executable->ExecuteOnStream(run_options, arguments,
- hlo_execution_profile);
- });
- } else {
- TF_RETURN_IF_ERROR(Service::ExecuteOnStreamWrapper<tensorflow::Status>(
- executable.get(), &service_run_options, profile,
- [&arguments, preallocated_result_buffer](
- Executable* executable,
- const ServiceExecutableRunOptions* run_options,
- HloExecutionProfile* hlo_execution_profile) {
- return executable->ExecuteOnStream(run_options, arguments,
- preallocated_result_buffer,
- hlo_execution_profile);
- }));
- // To satisfy the return value type, Return a null ShapedBuffer pointer.
- return std::unique_ptr<ShapedBuffer>();
- }
+ return Service::ExecuteOnStreamWrapper<
+ StatusOr<std::unique_ptr<ShapedBuffer>>>(
+ executable.get(), &service_run_options, profile,
+ [&arguments](Executable* executable,
+ const ServiceExecutableRunOptions* run_options,
+ HloExecutionProfile* hlo_execution_profile) {
+ return executable->ExecuteOnStream(run_options, arguments,
+ hlo_execution_profile);
+ });
}
StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h
index 9fe0d5993b..b0a5955e25 100644
--- a/tensorflow/compiler/xla/service/local_service.h
+++ b/tensorflow/compiler/xla/service/local_service.h
@@ -120,25 +120,6 @@ class LocalService : public Service {
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const LocalExecuteOptions& options);
- // Overload which writes the result into the given ShapedBuffer "result".
- // Due to aliasing, not all buffers which comprise "result" may be utilized
- // in the computation and thus be uninitialized. The |ShapedBuffer::buffer|
- // or |ShapedBuffer::mutable_buffer| methods should be used to map an index to
- // the initialized buffer.
- //
- // For example:
- // Let 'result' be a ShapedBuffer holding a tuple with the same element,
- // 'x', twice: (x, x). It is incorrect to assume that the second buffer
- // which comprises 'result' is initialized. Instead, a mapping has been
- // added to 'result' which can be used to recover the correct buffer.
- // In this case, result->buffer({0}) should be used to extract the address of
- // the first tuple element while result->buffer({1}) should be used for the
- // second.
- tensorflow::Status ExecuteLocally(
- const ComputationHandle& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const LocalExecuteOptions& options, ShapedBuffer* result_buffer);
-
// A description of a computation to compile using CompileAheadOfTime.
struct AheadOfTimeComputationInstance {
ComputationHandle computation;
@@ -169,23 +150,12 @@ class LocalService : public Service {
LocalService(const LocalService&) = delete;
void operator=(const LocalService&) = delete;
- // Internal helper for executing a computation. If result_buffer is null then
- // the result is returned as a ShapedBuffer. If result_buffer is non-null then
- // the result is written into result_buffer and a null ShapedBuffer pointer is
- // returned.
- StatusOr<std::unique_ptr<ShapedBuffer>> ExecuteLocallyInternal(
- const ComputationHandle& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const LocalExecuteOptions& options,
- ShapedBuffer* preallocated_result_buffer);
-
// Validates the given options and argument layouts and returns an appropriate
// error code.
tensorflow::Status ValidateExecuteOptions(
const ProgramShape& program_shape,
tensorflow::gtl::ArraySlice<const Shape*> arguments,
- const LocalExecuteOptions& options,
- const ShapedBuffer* preallocated_result_buffer);
+ const LocalExecuteOptions& options);
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index 5c32ed8895..aaddb92968 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -202,19 +202,4 @@ std::unique_ptr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
options.allocator());
}
-void LocalClientTestBase::ExecuteLocally(
- const Computation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- ShapedBuffer* result) {
- ExecuteLocally(computation, arguments, DefaultLocalExecuteOptions(), result);
-}
-
-void LocalClientTestBase::ExecuteLocally(
- const Computation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const LocalExecuteOptions& options, ShapedBuffer* result) {
- ASSERT_IS_OK(
- local_client_->ExecuteLocally(computation, arguments, options, result));
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
index 62916d50e3..7de3faaba6 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.h
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -111,16 +111,6 @@ class LocalClientTestBase : public ::testing::Test {
// as the allocator.
LocalExecuteOptions DefaultLocalExecuteOptions() const;
- // Overloads which write result into the given buffer.
- void ExecuteLocally(
- const Computation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- ShapedBuffer* result);
- void ExecuteLocally(
- const Computation& computation,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- const LocalExecuteOptions& options, ShapedBuffer* result);
-
// Convert a ShapedBuffer into a ScopedShaped buffer so that all buffers are
// deallocated when the object is destructed.
std::unique_ptr<ScopedShapedBuffer> ShapedBufferToScopedShapedBuffer(