aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/stream_executor/cuda/cuda_platform.cc2
-rw-r--r--tensorflow/stream_executor/host/host_platform.cc2
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc5
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h7
4 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_platform.cc b/tensorflow/stream_executor/cuda/cuda_platform.cc
index 62dced5604..9f58be4302 100644
--- a/tensorflow/stream_executor/cuda/cuda_platform.cc
+++ b/tensorflow/stream_executor/cuda/cuda_platform.cc
@@ -148,7 +148,7 @@ port::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
port::StatusOr<std::unique_ptr<StreamExecutor>>
CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
auto executor = port::MakeUnique<StreamExecutor>(
- this, new CUDAExecutor(config.plugin_config));
+ this, port::MakeUnique<CUDAExecutor>(config.plugin_config));
auto init_status = executor->Init(config.ordinal, config.device_options);
if (!init_status.ok()) {
return port::Status{
diff --git a/tensorflow/stream_executor/host/host_platform.cc b/tensorflow/stream_executor/host/host_platform.cc
index 59b4bfe5f0..e93ccff4d8 100644
--- a/tensorflow/stream_executor/host/host_platform.cc
+++ b/tensorflow/stream_executor/host/host_platform.cc
@@ -85,7 +85,7 @@ port::StatusOr<StreamExecutor*> HostPlatform::GetExecutor(
port::StatusOr<std::unique_ptr<StreamExecutor>>
HostPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
auto executor = port::MakeUnique<StreamExecutor>(
- this, new HostExecutor(config.plugin_config));
+ this, port::MakeUnique<HostExecutor>(config.plugin_config));
auto init_status = executor->Init(config.ordinal, config.device_options);
if (!init_status.ok()) {
return port::Status{
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 71a5a45b67..c498eecb3c 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -164,9 +164,10 @@ StreamExecutor::StreamExecutor(PlatformKind platform_kind,
}
StreamExecutor::StreamExecutor(
- const Platform *platform, internal::StreamExecutorInterface *implementation)
+ const Platform *platform,
+ std::unique_ptr<internal::StreamExecutorInterface> implementation)
: platform_(platform),
- implementation_(implementation),
+ implementation_(std::move(implementation)),
device_ordinal_(-1),
background_threads_(new port::ThreadPool(
port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 83fd27599e..a5da0e047e 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_
#include <atomic>
+#include <memory>
#include <set>
#include <tuple>
#include <vector>
@@ -71,8 +72,10 @@ class StreamExecutor {
public:
explicit StreamExecutor(PlatformKind kind,
const PluginConfig &plugin_config = PluginConfig());
- StreamExecutor(const Platform *platform,
- internal::StreamExecutorInterface *implementation);
+
+ StreamExecutor(
+ const Platform *platform,
+ std::unique_ptr<internal::StreamExecutorInterface> implementation);
~StreamExecutor();