diff options
4 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/stream_executor/cuda/cuda_platform.cc b/tensorflow/stream_executor/cuda/cuda_platform.cc index 62dced5604..9f58be4302 100644 --- a/tensorflow/stream_executor/cuda/cuda_platform.cc +++ b/tensorflow/stream_executor/cuda/cuda_platform.cc @@ -148,7 +148,7 @@ port::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor( port::StatusOr<std::unique_ptr<StreamExecutor>> CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { auto executor = port::MakeUnique<StreamExecutor>( - this, new CUDAExecutor(config.plugin_config)); + this, port::MakeUnique<CUDAExecutor>(config.plugin_config)); auto init_status = executor->Init(config.ordinal, config.device_options); if (!init_status.ok()) { return port::Status{ diff --git a/tensorflow/stream_executor/host/host_platform.cc b/tensorflow/stream_executor/host/host_platform.cc index 59b4bfe5f0..e93ccff4d8 100644 --- a/tensorflow/stream_executor/host/host_platform.cc +++ b/tensorflow/stream_executor/host/host_platform.cc @@ -85,7 +85,7 @@ port::StatusOr<StreamExecutor*> HostPlatform::GetExecutor( port::StatusOr<std::unique_ptr<StreamExecutor>> HostPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { auto executor = port::MakeUnique<StreamExecutor>( - this, new HostExecutor(config.plugin_config)); + this, port::MakeUnique<HostExecutor>(config.plugin_config)); auto init_status = executor->Init(config.ordinal, config.device_options); if (!init_status.ok()) { return port::Status{ diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 71a5a45b67..c498eecb3c 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -164,9 +164,10 @@ StreamExecutor::StreamExecutor(PlatformKind platform_kind, } StreamExecutor::StreamExecutor( - const Platform *platform, internal::StreamExecutorInterface *implementation) + const Platform *platform, + std::unique_ptr<internal::StreamExecutorInterface> implementation) : platform_(platform), - implementation_(implementation), + implementation_(std::move(implementation)), device_ordinal_(-1), background_threads_(new port::ThreadPool( port::Env::Default(), "stream_executor", kNumBackgroundThreads)), diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 83fd27599e..a5da0e047e 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_PIMPL_H_ #include <atomic> +#include <memory> #include <set> #include <tuple> #include <vector> @@ -71,8 +72,10 @@ class StreamExecutor { public: explicit StreamExecutor(PlatformKind kind, const PluginConfig &plugin_config = PluginConfig()); - StreamExecutor(const Platform *platform, - internal::StreamExecutorInterface *implementation); + + StreamExecutor( + const Platform *platform, + std::unique_ptr<internal::StreamExecutorInterface> implementation); ~StreamExecutor(); |