author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-02-15 11:21:25 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-02-15 11:43:59 -0800
commit    61197393ab39929e945e9adf1378659a5c2bbab1 (patch)
tree      abc897ac1b3088adc16d8fa14603948b6ee2f72f
parent    b875419c9455e6d1d1b3e757fa159011487da2bd (diff)
[XLA] Use `Pool<se::Stream>` as stream cache in backend, and use smart pointers rather than explicitly release acquired streams
Change: 147620836
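
In outline: the old API handed out a raw std::unique_ptr<se::Stream> from AcquireStream and required every call site to pair it with ReleaseStream, typically via a cleanup object. The new Pool<se::Stream> instead hands out a std::unique_ptr with a custom deleter, so a stream going out of scope is returned to the pool automatically. A minimal, self-contained sketch of that deleter pattern follows; Resource and the class names are illustrative stand-ins, not XLA types, and the real Pool below also guards its free list with a mutex (omitted here for brevity):

    #include <memory>
    #include <vector>

    struct Resource {};  // stand-in for se::Stream

    class ResourcePool {
     public:
      // Custom deleter: destroying the smart pointer returns the object to
      // the pool instead of freeing it.
      struct Deleter {
        void operator()(Resource* r) const { pool->Return(r); }
        ResourcePool* pool;
      };
      using SmartPtr = std::unique_ptr<Resource, Deleter>;

      SmartPtr Borrow() {
        Resource* r;
        if (free_.empty()) {
          r = new Resource();          // create on demand
        } else {
          r = free_.back().release();  // reuse a pooled instance
          free_.pop_back();
        }
        return SmartPtr(r, Deleter{this});
      }

     private:
      void Return(Resource* r) { free_.emplace_back(r); }

      std::vector<std::unique_ptr<Resource>> free_;
    };

    int main() {
      ResourcePool pool;
      {
        ResourcePool::SmartPtr p = pool.Borrow();
        // ... use *p ...
      }  // p leaves scope: its Resource goes back to the pool, not the heap
      ResourcePool::SmartPtr q = pool.Borrow();  // reuses the same instance
    }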
 tensorflow/compiler/xla/client/local_client.cc       | 50
 tensorflow/compiler/xla/service/BUILD                | 21
 tensorflow/compiler/xla/service/backend.cc           | 40
 tensorflow/compiler/xla/service/backend.h            | 32
 tensorflow/compiler/xla/service/execution_tracker.cc | 17
 tensorflow/compiler/xla/service/execution_tracker.h  | 18
 tensorflow/compiler/xla/service/local_service.cc     | 16
 tensorflow/compiler/xla/service/pool.h               | 84
 tensorflow/compiler/xla/service/pool_test.cc         | 40
 tensorflow/compiler/xla/service/service.cc           | 45
 10 files changed, 213 insertions(+), 150 deletions(-)
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index 384aae867b..c0759f20ae 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -67,35 +67,15 @@ bool ExecutableBuildOptions::has_hybrid_result() const {
}
namespace {
-
-// Convenience class which holds an acquired stream from the backend and
-// automatically releases it when destructed.
-class StreamManager {
- public:
- static StatusOr<std::unique_ptr<StreamManager>> AcquireStream(
- Backend* backend, int device_ordinal) {
- TF_ASSIGN_OR_RETURN(
- se::StreamExecutor * executor,
- backend->stream_executor(device_ordinal == -1
- ? backend->default_device_ordinal()
- : device_ordinal));
- TF_ASSIGN_OR_RETURN(std::unique_ptr<se::Stream> stream,
- backend->AcquireStream(executor));
- return WrapUnique(new StreamManager(backend, std::move(stream)));
+StatusOr<Backend::StreamPtr> BorrowStreamForDevice(int device_ordinal,
+ Backend* backend) {
+ if (device_ordinal < 0) {
+ device_ordinal = backend->default_device_ordinal();
}
-
- ~StreamManager() { backend_->ReleaseStream(std::move(stream_)); }
-
- se::Stream* stream() const { return stream_.get(); }
-
- private:
- StreamManager(Backend* backend, std::unique_ptr<se::Stream> stream)
- : backend_(backend), stream_(std::move(stream)) {}
-
- Backend* backend_;
- std::unique_ptr<se::Stream> stream_;
-};
-
+ TF_ASSIGN_OR_RETURN(se::StreamExecutor * exec,
+ backend->stream_executor(device_ordinal));
+ return backend->BorrowStream(exec);
+}
} // namespace
LocalExecutable::LocalExecutable(std::unique_ptr<Executable> executable,
@@ -186,12 +166,11 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalExecutable::Run(
TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options));
ExecutableRunOptions actual_options = options;
- std::unique_ptr<StreamManager> acquired_stream;
+ Backend::StreamPtr stream;
if (options.stream() == nullptr) {
TF_ASSIGN_OR_RETURN(
- acquired_stream,
- StreamManager::AcquireStream(backend_, options.device_ordinal()));
- actual_options.set_stream(acquired_stream->stream());
+ stream, BorrowStreamForDevice(options.device_ordinal(), backend_));
+ actual_options.set_stream(stream.get());
}
if (options.allocator() == nullptr) {
actual_options.set_allocator(backend_->memory_allocator());
@@ -222,12 +201,11 @@ tensorflow::Status LocalExecutable::Run(
}
ExecutableRunOptions actual_options = options;
- std::unique_ptr<StreamManager> acquired_stream;
+ Backend::StreamPtr stream;
if (options.stream() == nullptr) {
TF_ASSIGN_OR_RETURN(
- acquired_stream,
- StreamManager::AcquireStream(backend_, options.device_ordinal()));
- actual_options.set_stream(acquired_stream->stream());
+ stream, BorrowStreamForDevice(options.device_ordinal(), backend_));
+ actual_options.set_stream(stream.get());
}
if (options.allocator() == nullptr) {
actual_options.set_allocator(backend_->memory_allocator());
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 7f9c95607b..a3c7ca9906 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -170,6 +170,7 @@ cc_library(
":compiler",
":device_memory_allocator",
":platform_util",
+ ":pool",
":transfer_manager",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -379,6 +380,7 @@ cc_library(
hdrs = ["execution_tracker.h"],
deps = [
":backend",
+ ":pool",
"//tensorflow/compiler/xla:executable_run_options",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
@@ -1286,6 +1288,25 @@ cc_test(
],
)
+cc_library(
+ name = "pool",
+ hdrs = ["pool.h"],
+ deps = [
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "pool_test",
+ srcs = ["pool_test.cc"],
+ deps = [
+ ":pool",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/core:test_main",
+ ],
+)
+
# -----------------------------------------------------------------------------
filegroup(
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index 7452a7b696..e58987635f 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -83,40 +83,26 @@ Backend::CreateDefaultBackend() {
}
tensorflow::Status Backend::PoolStreams(int n, se::StreamExecutor* executor) {
- std::vector<std::unique_ptr<se::Stream>> primed;
+ std::vector<StreamPtr> primed;
for (int i = 0; i < n; ++i) {
- TF_ASSIGN_OR_RETURN(auto stream, AcquireStream(executor));
+ TF_ASSIGN_OR_RETURN(auto stream, BorrowStream(executor));
primed.emplace_back(std::move(stream));
}
- for (int i = 0; i < n; ++i) {
- ReleaseStream(std::move(primed.back()));
- primed.pop_back();
- }
return tensorflow::Status::OK();
}
-StatusOr<std::unique_ptr<perftools::gputools::Stream>> Backend::AcquireStream(
- perftools::gputools::StreamExecutor* executor) {
- tensorflow::mutex_lock lock(mutex_);
- auto& cached_streams = cached_streams_[executor];
- if (!cached_streams.empty()) {
- auto result = std::move(cached_streams.back());
- cached_streams.pop_back();
- return std::move(result);
- }
-
- auto stream = MakeUnique<se::Stream>(executor);
- if (!stream->Init().ok()) {
- return InternalError("failed to initialize stream");
+StatusOr<Backend::StreamPtr> Backend::BorrowStream(
+ se::StreamExecutor* executor) {
+ if (0 == stream_pools_.count(executor)) {
+ stream_pools_.emplace(std::piecewise_construct,
+ std::forward_as_tuple(executor),
+ std::forward_as_tuple([executor]() {
+ auto stream = MakeUnique<se::Stream>(executor);
+ stream->Init();
+ return stream;
+ }));
}
- return std::move(stream);
-}
-
-void Backend::ReleaseStream(
- std::unique_ptr<perftools::gputools::Stream> stream) {
- tensorflow::mutex_lock lock(mutex_);
- auto& streams = cached_streams_[stream->parent()];
- streams.emplace_back(std::move(stream));
+ return stream_pools_.at(executor).Allocate();
}
Backend::Backend(
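
An aside on the map-insertion idiom in BorrowStream above: Pool owns a tensorflow::mutex, so it is neither copyable nor movable, and a plain stream_pools_.emplace(executor, Pool(...)) would not compile because it requires a move. std::piecewise_construct builds the Pool in place inside the map node from the factory lambda. A self-contained illustration of the same idiom, with toy types rather than XLA's:

    #include <map>
    #include <mutex>
    #include <tuple>
    #include <utility>

    // Like xla::Pool, this type owns a mutex and can be neither copied
    // nor moved.
    struct Immovable {
      explicit Immovable(int id) : id(id) {}
      Immovable(const Immovable&) = delete;
      Immovable& operator=(const Immovable&) = delete;
      int id;
      std::mutex mu;
    };

    int main() {
      std::map<int, Immovable> m;
      // m.emplace(1, Immovable(42)) would not compile: it needs a move.
      // piecewise_construct forwards each tuple to the respective
      // constructor, building key and value directly inside the map node.
      m.emplace(std::piecewise_construct,
                std::forward_as_tuple(1),    // key constructor arguments
                std::forward_as_tuple(42));  // value constructor arguments
      return m.at(1).id == 42 ? 0 : 1;
    }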
diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h
index db482c09ae..9461004c4f 100644
--- a/tensorflow/compiler/xla/service/backend.h
+++ b/tensorflow/compiler/xla/service/backend.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
+#include "tensorflow/compiler/xla/service/pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -43,12 +44,11 @@ namespace xla {
//
// It also offers a pooling API for creation/use of initialized streams:
//
-// std::unique_ptr<se::Stream> stream =
-// backend->AcquireStream().ConsumeValueOrDie();
-// // ... use stream ...
-// backend->ReleaseStream(std::move(stream));
+// StreamPtr stream = backend->BorrowStream().ConsumeValueOrDie();
class Backend {
public:
+ using StreamPtr = Pool<perftools::gputools::Stream>::SmartPtr;
+
// The number of streams we create for the pool at initialization time.
static constexpr int kInitialStreamsToPool = 8;
@@ -108,23 +108,17 @@ class Backend {
return stream_executors_[0];
}
- // Primes the internal pool of streams for AcquireStream/ReleaseStream with n
- // initialized stream instances.
+ // Primes the internal pool of streams for BorrowStream with n initialized
+ // stream instances.
tensorflow::Status PoolStreams(int n,
perftools::gputools::StreamExecutor* executor);
- // Acquires a stream for use by the caller, either by grabbing it from an
+ // Borrows a stream for use by the caller, either by grabbing it from an
// internal pool, or by constructing/initializing it, and returns the result
// to the caller.
- //
- // TODO(b/32989582): Return std::unique_ptr with custom deleter.
- StatusOr<std::unique_ptr<perftools::gputools::Stream>> AcquireStream(
+ StatusOr<StreamPtr> BorrowStream(
perftools::gputools::StreamExecutor* executor);
- // Releases a stream from the caller to the internal pool, for use with the
- // paired AcquireStream above.
- void ReleaseStream(std::unique_ptr<perftools::gputools::Stream> stream);
-
// Returns whether the given device ordinal of the backend is supported.
bool device_ordinal_supported(int device_ordinal) const {
return (device_ordinal >= 0 && device_ordinal < device_count() &&
@@ -170,14 +164,10 @@ class Backend {
// Vector of stream executors. stream_executors_[0] is the default executor.
std::vector<perftools::gputools::StreamExecutor*> stream_executors_;
- // Guards the mutable state in the backend object.
- tensorflow::mutex mutex_;
-
- // Mapping from stream executor to cached streams, used by
- // AcquireStream/ReleaseStream above.
+ // Mapping from stream executor to stream pools, used by `BorrowStream` above.
std::map<perftools::gputools::StreamExecutor*,
- std::vector<std::unique_ptr<perftools::gputools::Stream>>>
- cached_streams_ GUARDED_BY(mutex_);
+ Pool<perftools::gputools::Stream>>
+ stream_pools_;
// The default memory allocator to use.
std::unique_ptr<StreamExecutorMemoryAllocator> memory_allocator_;
diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc
index cf1870580c..8d79d07f94 100644
--- a/tensorflow/compiler/xla/service/execution_tracker.cc
+++ b/tensorflow/compiler/xla/service/execution_tracker.cc
@@ -24,10 +24,10 @@ limitations under the License.
namespace xla {
-AsyncExecution::AsyncExecution(
- Backend* backend,
- std::vector<std::unique_ptr<perftools::gputools::Stream>> streams,
- const ExecutionProfile& profile, GlobalDataHandle result)
+AsyncExecution::AsyncExecution(Backend* backend,
+ std::vector<Backend::StreamPtr> streams,
+ const ExecutionProfile& profile,
+ GlobalDataHandle result)
: backend_(CHECK_NOTNULL(backend)),
streams_(std::move(streams)),
profile_(profile),
@@ -37,12 +37,6 @@ AsyncExecution::AsyncExecution(
}
}
-AsyncExecution::~AsyncExecution() {
- for (auto& stream : streams_) {
- backend_->ReleaseStream(std::move(stream));
- }
-}
-
tensorflow::Status AsyncExecution::BlockUntilDone() const {
for (auto& stream : streams_) {
if (!stream->BlockHostUntilDone()) {
@@ -55,8 +49,7 @@ tensorflow::Status AsyncExecution::BlockUntilDone() const {
ExecutionTracker::ExecutionTracker() : next_handle_(1) {}
ExecutionHandle ExecutionTracker::Register(
- Backend* backend,
- std::vector<std::unique_ptr<perftools::gputools::Stream>> streams,
+ Backend* backend, std::vector<Backend::StreamPtr> streams,
const ExecutionProfile& profile, GlobalDataHandle result) {
tensorflow::mutex_lock lock(execution_mutex_);
int64 handle = next_handle_++;
diff --git a/tensorflow/compiler/xla/service/execution_tracker.h b/tensorflow/compiler/xla/service/execution_tracker.h
index 99a5bb5ad9..5b6bddf9f1 100644
--- a/tensorflow/compiler/xla/service/execution_tracker.h
+++ b/tensorflow/compiler/xla/service/execution_tracker.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/executable_run_options.h"
#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/pool.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -39,12 +40,9 @@ namespace xla {
// the stream when destructed.
class AsyncExecution {
public:
- AsyncExecution(
- Backend* backend,
- std::vector<std::unique_ptr<perftools::gputools::Stream>> streams,
- const ExecutionProfile& profile, GlobalDataHandle result);
+ AsyncExecution(Backend* backend, std::vector<Backend::StreamPtr> streams,
+ const ExecutionProfile& profile, GlobalDataHandle result);
- ~AsyncExecution();
tensorflow::Status BlockUntilDone() const;
const GlobalDataHandle& result() const { return result_; }
@@ -56,7 +54,7 @@ class AsyncExecution {
Backend* backend_;
// Stream on which the execution is launched.
- std::vector<std::unique_ptr<perftools::gputools::Stream>> streams_;
+ std::vector<Backend::StreamPtr> streams_;
// Profile object of the execution to be returned to the user.
ExecutionProfile profile_;
@@ -73,10 +71,10 @@ class ExecutionTracker {
// Registers an execution with its backend, streams, and data handle to the
// execution result. Returns a handle for the registered execution.
- ExecutionHandle Register(
- Backend* backend,
- std::vector<std::unique_ptr<perftools::gputools::Stream>> stream,
- const ExecutionProfile& profile, GlobalDataHandle data);
+ ExecutionHandle Register(Backend* backend,
+ std::vector<Backend::StreamPtr> stream,
+ const ExecutionProfile& profile,
+ GlobalDataHandle data);
// Unregisters the execution for the given handle.
tensorflow::Status Unregister(const ExecutionHandle& handle);
diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc
index 73d6305362..402cc2b615 100644
--- a/tensorflow/compiler/xla/service/local_service.cc
+++ b/tensorflow/compiler/xla/service/local_service.cc
@@ -426,8 +426,8 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalService::ExecuteLocallyInternal(
run_options.set_intra_op_thread_pool(
execute_backend_->eigen_intra_op_thread_pool_device());
- // "acquired_stream" owns the stream used for execution if no stream is given.
- std::unique_ptr<se::Stream> acquired_stream;
+ // "stream" owns the stream used for execution if no stream is given.
+ Backend::StreamPtr stream;
if (options.stream()) {
run_options.set_stream(options.stream());
} else {
@@ -439,16 +439,10 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalService::ExecuteLocallyInternal(
} else {
stream_executor = execute_backend_->default_stream_executor();
}
- TF_ASSIGN_OR_RETURN(acquired_stream,
- execute_backend_->AcquireStream(stream_executor));
- run_options.set_stream(acquired_stream.get());
+ TF_ASSIGN_OR_RETURN(stream,
+ execute_backend_->BorrowStream(stream_executor));
+ run_options.set_stream(stream.get());
}
- auto stream_releaser =
- ::tensorflow::gtl::MakeCleanup([this, &acquired_stream]() {
- if (acquired_stream != nullptr) {
- execute_backend_->ReleaseStream(std::move(acquired_stream));
- }
- });
ExecutionProfile* profile = options.execution_profile();
TF_ASSIGN_OR_RETURN(
diff --git a/tensorflow/compiler/xla/service/pool.h b/tensorflow/compiler/xla/service/pool.h
new file mode 100644
index 0000000000..8e710ebb6d
--- /dev/null
+++ b/tensorflow/compiler/xla/service/pool.h
@@ -0,0 +1,84 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_POOL_H_
+#define TENSORFLOW_COMPILER_XLA_POOL_H_
+
+#include <functional>
+#include <vector>
+
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace xla {
+
+// Pool of values, which are created as needed and destroyed when the `Pool` is
+// destroyed
+template <typename T>
+class Pool {
+ public:
+ struct Deleter {
+ void operator()(T* ptr) { pool->Deallocate(ptr); }
+
+ Pool<T>* pool;
+ };
+
+ // A pointer to a taken element of a `Pool` which returns it to the pool on
+ // destruction
+ using SmartPtr = std::unique_ptr<T, Deleter>;
+
+ // Constructs a `Pool` with given factory function, which need not be
+ // thread-safe.
+ explicit Pool(std::function<std::unique_ptr<T>()> factory)
+ : factory_(factory) {}
+
+ explicit Pool() : Pool([]() { return MakeUnique<T>(); }) {}
+
+ // Returns a pointer to a value in the pool, creating a new value if none is
+ // free. The returned smart pointer returns the element to the pool on
+ // destruction.
+ //
+ // This method is thread-safe.
+ SmartPtr Allocate() {
+ tensorflow::mutex_lock lock(mu_);
+ T* ptr;
+ if (!xs_.empty()) {
+ ptr = std::move(xs_.back()).release();
+ xs_.pop_back();
+ } else {
+ ptr = factory_().release();
+ }
+ Deleter del = {this};
+ return std::unique_ptr<T, Deleter>(ptr, del);
+ }
+
+ private:
+ // Puts a pointer to a value back into the pool, leaving it free for future
+ // use.
+ //
+ // This method is thread-safe.
+ void Deallocate(T* ptr) {
+ tensorflow::mutex_lock lock(mu_);
+ xs_.push_back(std::unique_ptr<T>(ptr));
+ }
+
+ const std::function<std::unique_ptr<T>()> factory_ GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<T>> xs_ GUARDED_BY(mu_);
+ tensorflow::mutex mu_;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_POOL_H_
diff --git a/tensorflow/compiler/xla/service/pool_test.cc b/tensorflow/compiler/xla/service/pool_test.cc
new file mode 100644
index 0000000000..8c4fe258e3
--- /dev/null
+++ b/tensorflow/compiler/xla/service/pool_test.cc
@@ -0,0 +1,40 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/pool.h"
+
+#include "tensorflow/compiler/xla/test_helpers.h"
+
+namespace xla {
+namespace {
+
+using PoolTest = ::testing::Test;
+
+TEST_F(PoolTest, Test) {
+ Pool<int> pool;
+
+ {
+ auto ptr = pool.Allocate();
+ EXPECT_NE(nullptr, ptr.get());
+ *ptr = 5;
+ }
+
+ auto ptr = pool.Allocate();
+ EXPECT_NE(nullptr, ptr.get());
+ EXPECT_EQ(5, *ptr);
+}
+
+} // namespace
+} // namespace xla
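
The test above only exercises Pool's default factory. As a complement, here is a hedged sketch (not part of the commit) of the factory-constructor path that backend.cc relies on, with std::string standing in for se::Stream and std::make_unique in place of xla's MakeUnique:

    #include <memory>
    #include <string>

    #include "tensorflow/compiler/xla/service/pool.h"

    // Hypothetical usage: a Pool whose elements come from a caller-supplied
    // factory, mirroring how Backend::BorrowStream builds a per-executor
    // stream pool.
    void FactoryPoolExample() {
      xla::Pool<std::string> pool(
          [] { return std::make_unique<std::string>("initialized"); });

      auto ptr = pool.Allocate();  // runs the factory only if no element is free
      ptr->append(" and used");    // ... use the pooled value ...
    }  // `ptr` is returned to the pool here, not destroyed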
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 95bbb01e9e..d5b8457b65 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -498,24 +498,17 @@ Service::ExecuteParallelAndRegisterResult(
TF_RET_CHECK(backend->Replicas().size() == 1);
// Set up streams.
- std::vector<std::unique_ptr<se::Stream>> streams;
-
- auto stream_releaser = ::tensorflow::gtl::MakeCleanup([backend, &streams]() {
- for (std::unique_ptr<se::Stream>& stream : streams) {
- backend->ReleaseStream(std::move(stream));
- }
- });
+ std::vector<Pool<se::Stream>::SmartPtr> streams;
for (se::StreamExecutor* executor : executors) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<se::Stream> stream,
- backend->AcquireStream(executor));
- // Push back after so that the releaser only sees real streams.
+ TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
+ backend->BorrowStream(executor));
streams.push_back(std::move(stream));
}
// Set up run options.
std::vector<ExecutableRunOptions> run_options;
- for (const std::unique_ptr<se::Stream>& stream : streams) {
+ for (const Pool<se::Stream>::SmartPtr& stream : streams) {
run_options.emplace_back();
auto& options = run_options.back();
options.set_stream(stream.get());
@@ -555,24 +548,17 @@ StatusOr<GlobalDataHandle> Service::ExecuteAndRegisterResult(
TF_RET_CHECK(!backend->Replicas().empty());
// Set up streams.
- std::vector<std::unique_ptr<se::Stream>> streams;
-
- auto stream_releaser = ::tensorflow::gtl::MakeCleanup([backend, &streams]() {
- for (std::unique_ptr<se::Stream>& stream : streams) {
- backend->ReleaseStream(std::move(stream));
- }
- });
+ std::vector<Pool<se::Stream>::SmartPtr> streams;
for (se::StreamExecutor* executor : backend->Replicas()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<se::Stream> stream,
- backend->AcquireStream(executor));
- // Push back after so that the releaser only sees real streams.
+ TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
+ backend->BorrowStream(executor));
streams.push_back(std::move(stream));
}
// Set up run options.
std::vector<ExecutableRunOptions> run_options;
- for (const std::unique_ptr<se::Stream>& stream : streams) {
+ for (const Pool<se::Stream>::SmartPtr& stream : streams) {
run_options.emplace_back();
auto& options = run_options.back();
options.set_stream(stream.get());
@@ -851,23 +837,16 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
TF_RET_CHECK(!execute_backend_->Replicas().empty());
// Set up streams.
- std::vector<std::unique_ptr<se::Stream>> streams;
-
- auto stream_releaser = ::tensorflow::gtl::MakeCleanup([this, &streams]() {
- for (std::unique_ptr<se::Stream>& stream : streams) {
- execute_backend_->ReleaseStream(std::move(stream));
- }
- });
+ std::vector<Pool<se::Stream>::SmartPtr> streams;
for (se::StreamExecutor* executor : execute_backend_->Replicas()) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<se::Stream> stream,
- execute_backend_->AcquireStream(executor));
- // Push back after so that the releaser only sees real streams.
+ TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
+ execute_backend_->BorrowStream(executor));
streams.push_back(std::move(stream));
}
perftools::gputools::DeviceMemoryBase result_data;
- for (const std::unique_ptr<se::Stream>& stream : streams) {
+ for (const Pool<se::Stream>::SmartPtr& stream : streams) {
ExecutableRunOptions options;
options.set_stream(stream.get());
options.set_allocator(execute_backend_->memory_allocator());