diff options
author | Justin Lebar <jlebar@google.com> | 2018-04-23 17:16:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-23 17:19:22 -0700 |
commit | 80fc661853f9a0844faf95eb68438dc85a5879e3 (patch) | |
tree | e7328a367c4f4db472063f96c983b70e63aa3f26 /tensorflow/contrib/nccl | |
parent | ecd837fd0ab69cf54d920eae3b1c73602be6c626 (diff) |
Use tensorflow::se instead of perftools::gputools for StreamExecutor.
PiperOrigin-RevId: 194010749
Diffstat (limited to 'tensorflow/contrib/nccl')
-rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_manager.cc | 56 | ||||
-rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_manager.h | 36 | ||||
-rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_manager_test.cc | 8 |
3 files changed, 44 insertions, 56 deletions
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc index b9b482a698..b1cb89391c 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc @@ -24,7 +24,7 @@ limitations under the License. namespace tensorflow { -using ::perftools::gputools::cuda::ScopedActivateExecutorContext; +using se::cuda::ScopedActivateExecutorContext; // Contains data for a single stream used for nccl communication; this includes // a background thread that calls NcclManager::LoopKernelLaunches. @@ -37,11 +37,11 @@ struct NcclManager::NcclStream { cv.notify_all(); } - perftools::gputools::StreamExecutor* executor = nullptr; + se::StreamExecutor* executor = nullptr; // The stream on which to run the nccl collective. // This is a different stream than the tensorflow compute stream. - std::unique_ptr<perftools::gputools::Stream> stream; + std::unique_ptr<se::Stream> stream; // See NcclManager::LoopKernelLaunches for information on these. std::unique_ptr<Thread> thread; @@ -95,9 +95,8 @@ ncclDataType_t ToNcclType(DataType t) { // A participant in a Collective. See <Collective> below. struct NcclManager::Participant { Participant(const Tensor* in_t, Tensor* out_t, EventMgr* event_mgr, - perftools::gputools::Stream* tensor_stream, - perftools::gputools::StreamExecutor* executor, int gpu_device_id, - NcclManager::DoneCallback done_callback) + se::Stream* tensor_stream, se::StreamExecutor* executor, + int gpu_device_id, NcclManager::DoneCallback done_callback) : in_t(in_t), out_t(out_t), event_mgr(event_mgr), @@ -121,11 +120,11 @@ struct NcclManager::Participant { EventMgr* const event_mgr; // Owned by the caller, who must keep it live until <done_callback> is called. - perftools::gputools::Stream* const tensor_stream; + se::Stream* const tensor_stream; // Matches the executor in CommunicatorMember::stream. Expected to be live for // process lifetime. - perftools::gputools::StreamExecutor* const executor = nullptr; + se::StreamExecutor* const executor = nullptr; const int gpu_device_id; @@ -245,7 +244,7 @@ NcclManager::Communicator* NcclManager::GetCommunicator( if (nccl_stream == nullptr) { nccl_stream = new NcclStream(); nccl_stream->executor = executor; - nccl_stream->stream.reset(new perftools::gputools::Stream(executor)); + nccl_stream->stream.reset(new se::Stream(executor)); nccl_stream->stream->Init(); streams.emplace_back(nccl_stream); @@ -300,10 +299,10 @@ NcclManager::Communicator* NcclManager::GetCommunicator( void NcclManager::AddToAllReduce(int num_devices, const string& key, ncclRedOp_t reduction_op, - perftools::gputools::StreamExecutor* executor, + se::StreamExecutor* executor, int gpu_device_id, EventMgr* event_mgr, - perftools::gputools::Stream* tensor_stream, - const Tensor* in_t, Tensor* out_t, + se::Stream* tensor_stream, const Tensor* in_t, + Tensor* out_t, const DoneCallback& done_callback) { std::unique_ptr<Participant> participant( new Participant(in_t, out_t, event_mgr, tensor_stream, executor, @@ -312,11 +311,12 @@ void NcclManager::AddToAllReduce(int num_devices, const string& key, kAllReduce, reduction_op); } -void NcclManager::AddBroadcastSend( - int num_devices, const string& key, - perftools::gputools::StreamExecutor* executor, int gpu_device_id, - EventMgr* event_mgr, perftools::gputools::Stream* tensor_stream, - const Tensor* in_t, DoneCallback done_callback) { +void NcclManager::AddBroadcastSend(int num_devices, const string& key, + se::StreamExecutor* executor, + int gpu_device_id, EventMgr* event_mgr, + se::Stream* tensor_stream, + const Tensor* in_t, + DoneCallback done_callback) { std::unique_ptr<Participant> participant( new Participant(in_t, nullptr /* out_t */, event_mgr, tensor_stream, executor, gpu_device_id, std::move(done_callback))); @@ -325,11 +325,11 @@ void NcclManager::AddBroadcastSend( kBroadcast, ncclSum /* unused */); } -void NcclManager::AddBroadcastRecv( - int num_devices, const string& key, - perftools::gputools::StreamExecutor* executor, int gpu_device_id, - EventMgr* event_mgr, perftools::gputools::Stream* tensor_stream, - Tensor* out_t, DoneCallback done_callback) { +void NcclManager::AddBroadcastRecv(int num_devices, const string& key, + se::StreamExecutor* executor, + int gpu_device_id, EventMgr* event_mgr, + se::Stream* tensor_stream, Tensor* out_t, + DoneCallback done_callback) { std::unique_ptr<Participant> participant( new Participant(nullptr /* in_t */, out_t, event_mgr, tensor_stream, executor, gpu_device_id, std::move(done_callback))); @@ -339,9 +339,8 @@ void NcclManager::AddBroadcastRecv( void NcclManager::AddReduceSend(int num_devices, const string& key, ncclRedOp_t reduction_op, - perftools::gputools::StreamExecutor* executor, - int gpu_device_id, EventMgr* event_mgr, - perftools::gputools::Stream* tensor_stream, + se::StreamExecutor* executor, int gpu_device_id, + EventMgr* event_mgr, se::Stream* tensor_stream, const Tensor* in_t, DoneCallback done_callback) { std::unique_ptr<Participant> participant( @@ -353,9 +352,8 @@ void NcclManager::AddReduceSend(int num_devices, const string& key, void NcclManager::AddReduceRecv(int num_devices, const string& key, ncclRedOp_t reduction_op, - perftools::gputools::StreamExecutor* executor, - int gpu_device_id, EventMgr* event_mgr, - perftools::gputools::Stream* tensor_stream, + se::StreamExecutor* executor, int gpu_device_id, + EventMgr* event_mgr, se::Stream* tensor_stream, const Tensor* in_t, Tensor* out_t, DoneCallback done_callback) { std::unique_ptr<Participant> participant( @@ -444,7 +442,7 @@ void NcclManager::RunCollective(const string& key, Collective* collective) { } void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) { - perftools::gputools::Stream* comm_stream = nccl_stream->stream.get(); + se::Stream* comm_stream = nccl_stream->stream.get(); ScopedActivateExecutorContext scoped_context(nccl_stream->executor); const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>( comm_stream->implementation()->CudaStreamMemberHack()); diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h index 6ff8cea84e..57a96c5d33 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager.h +++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h @@ -55,41 +55,34 @@ class NcclManager { // is also the stream that will use the produced data; <done_callback> is // not called until the next kernel launched on <stream> would see the data. void AddToAllReduce(int num_devices, const string& key, - ncclRedOp_t reduction_op, - perftools::gputools::StreamExecutor* executor, + ncclRedOp_t reduction_op, se::StreamExecutor* executor, int gpu_device_id, EventMgr* event_mgr, - perftools::gputools::Stream* tensor_stream, - const Tensor* in_t, Tensor* out_t, - const DoneCallback& done_callback); + se::Stream* tensor_stream, const Tensor* in_t, + Tensor* out_t, const DoneCallback& done_callback); // AddBroadcastSend and AddBroadcastRecv combine to sent data from one sender // to all receivers. void AddBroadcastSend(int num_devices, const string& key, - perftools::gputools::StreamExecutor* executor, - int gpu_device_id, EventMgr* event_mgr, - perftools::gputools::Stream* tensor_stream, + se::StreamExecutor* executor, int gpu_device_id, + EventMgr* event_mgr, se::Stream* tensor_stream, const Tensor* in_t, DoneCallback done_callback); void AddBroadcastRecv(int num_devices, const string& key, - perftools::gputools::StreamExecutor* executor, - int gpu_device_id, EventMgr* event_mgr, - perftools::gputools::Stream* tensor_stream, + se::StreamExecutor* executor, int gpu_device_id, + EventMgr* event_mgr, se::Stream* tensor_stream, Tensor* out_t, DoneCallback done_callback); // AddReduceSend and AddReduceRecv combine to sent data from all senders // to one receiver. void AddReduceSend(int num_devices, const string& key, - ncclRedOp_t reduction_op, - perftools::gputools::StreamExecutor* executor, + ncclRedOp_t reduction_op, se::StreamExecutor* executor, int gpu_device_id, EventMgr* event_mgr, - perftools::gputools::Stream* tensor_stream, - const Tensor* in_t, DoneCallback done_callback); + se::Stream* tensor_stream, const Tensor* in_t, + DoneCallback done_callback); void AddReduceRecv(int num_devices, const string& key, - ncclRedOp_t reduction_op, - perftools::gputools::StreamExecutor* executor, + ncclRedOp_t reduction_op, se::StreamExecutor* executor, int gpu_device_id, EventMgr* event_mgr, - perftools::gputools::Stream* tensor_stream, - const Tensor* in_t, Tensor* out_t, - DoneCallback done_callback); + se::Stream* tensor_stream, const Tensor* in_t, + Tensor* out_t, DoneCallback done_callback); private: enum CollectiveType { @@ -123,8 +116,7 @@ class NcclManager { // Maps a device to the communication streams that make up its collective. // This is used to share the stream across different communicators that // include the same device. - std::map<perftools::gputools::StreamExecutor*, - std::vector<std::unique_ptr<NcclStream>>> + std::map<se::StreamExecutor*, std::vector<std::unique_ptr<NcclStream>>> device_to_comm_streams_ GUARDED_BY(mu_); std::vector<std::unique_ptr<Communicator>> communicators_; diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc index 06ca65e33a..4d8d922cb4 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc @@ -175,11 +175,9 @@ class NcclManagerTest : public ::testing::Test { nullptr /* step_resource_manager */); } - static perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory( - const Scalar* cuda_memory) { - perftools::gputools::DeviceMemoryBase wrapped( - const_cast<Scalar*>(cuda_memory)); - perftools::gputools::DeviceMemory<Scalar> typed(wrapped); + static se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* cuda_memory) { + se::DeviceMemoryBase wrapped(const_cast<Scalar*>(cuda_memory)); + se::DeviceMemory<Scalar> typed(wrapped); return typed; } |