/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements the StreamExecutor interface by passing through to its
// implementation_ value (in pointer-to-implementation style), which
// implements StreamExecutorInterface.

#include "tensorflow/stream_executor/stream_executor_pimpl.h"

#include <atomic>
#include <utility>

#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/notification.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/rng.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"

namespace {
bool FLAGS_check_device_leaks = false;
}  // namespace

namespace stream_executor {
namespace {

string StackTraceIfVLOG10() {
  if (VLOG_IS_ON(10)) {
    return port::StrCat(" ", port::CurrentStackTrace(), "\n");
  } else {
    return "";
  }
}

// Make sure the executor is done with its work; we know (because this isn't
// publicly visible) that all enqueued work is quick.
void BlockOnThreadExecutor(port::ThreadPool *executor) {
  port::Notification n;
  executor->Schedule([&n]() { n.Notify(); });
  n.WaitForNotification();
}

internal::StreamExecutorInterface *StreamExecutorImplementationFromPlatformKind(
    PlatformKind platform_kind, const PluginConfig &plugin_config) {
  // Note: we use this factory-assignment-in-switch pattern instead of just
  // invoking the callable in case linkage is messed up -- instead of invoking
  // a nullptr std::function (due to failed registration) we give a nice
  // LOG(FATAL) message.
  internal::StreamExecutorFactory factory;
  switch (platform_kind) {
    case PlatformKind::kCuda:
      factory = *internal::MakeCUDAExecutorImplementation();
      break;
    case PlatformKind::kOpenCL:
      factory = *internal::MakeOpenCLExecutorImplementation();
      break;
    case PlatformKind::kHost:
      factory = internal::MakeHostExecutorImplementation;
      break;
    default:
      factory = nullptr;
  }

  if (factory == nullptr) {
    LOG(FATAL)
        << "cannot create StreamExecutor implementation for platform kind: "
        << PlatformKindString(platform_kind);
  }
  return factory(plugin_config);
}

std::atomic_int_fast64_t correlation_id_generator(0);

}  // namespace
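// RAII helper used by the SCOPED_TRACE macro below: it reports a "Begin"
// trace event to all registered TraceListeners when constructed and the
// matching "Complete" event (with the call's result) when destroyed, tagging
// both with a correlation id so listeners can pair the two events.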
template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
class ScopedTracer {
 public:
  ScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
               CompleteCallT complete_call, const ReturnT *result,
               BeginArgsT... begin_args)
      : stream_exec_(stream_exec),
        complete_call_(complete_call),
        result_(result) {
    if (stream_exec_->tracing_enabled_) {
      correlation_id_ =
          correlation_id_generator.fetch_add(1, std::memory_order_relaxed) - 1;
      Trace(begin_call, begin_args...);
    }
  }

  ~ScopedTracer() {
    if (stream_exec_->tracing_enabled_) {
      Trace(complete_call_, result_);
    }
  }

 private:
  template <typename CallbackT, typename... TraceArgsT>
  void Trace(CallbackT callback, TraceArgsT... args) {
    {
      // Instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock{stream_exec_->mu_};
      for (TraceListener *listener : stream_exec_->listeners_) {
        (listener->*callback)(correlation_id_,
                              std::forward<TraceArgsT>(args)...);
      }
    }
  }

  StreamExecutor *stream_exec_;
  CompleteCallT complete_call_;
  const ReturnT *result_;
  int64 correlation_id_;
};

template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
          typename... BeginArgsT>
ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>
MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
                 CompleteCallT complete_call, ReturnT *result,
                 BeginArgsT... begin_args) {
  return ScopedTracer<BeginCallT, CompleteCallT, ReturnT, BeginArgsT...>(
      stream_exec, begin_call, complete_call, result,
      std::forward<BeginArgsT>(begin_args)...);
}

#define SCOPED_TRACE(LOC, ...)                          \
  auto tracer = MakeScopedTracer(this, &LOC ## Begin,   \
                                 &LOC ## Complete, ## __VA_ARGS__);

/* static */ mutex StreamExecutor::static_mu_{LINKER_INITIALIZED};

StreamExecutor::StreamExecutor(PlatformKind platform_kind,
                               const PluginConfig &plugin_config)
    : platform_(nullptr),
      implementation_(StreamExecutorImplementationFromPlatformKind(
          platform_kind, plugin_config)),
      platform_kind_(platform_kind),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false) {
  CheckPlatformKindIsValid(platform_kind);
}

// Get per-device memory limit in bytes. Returns 0 if
// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
static int64 GetMemoryLimitBytes() {
  int64 value;
  SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
                                              0, &value));
  return value * (1ll << 20);
}

StreamExecutor::StreamExecutor(
    const Platform *platform,
    std::unique_ptr<internal::StreamExecutorInterface> implementation)
    : platform_(platform),
      implementation_(std::move(implementation)),
      device_ordinal_(-1),
      background_threads_(new port::ThreadPool(
          port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
      live_stream_count_(0),
      tracing_enabled_(false),
      mem_alloc_bytes_(0),
      memory_limit_bytes_(GetMemoryLimitBytes()) {
  if (port::Lowercase(platform_->Name()) == "cuda") {
    platform_kind_ = PlatformKind::kCuda;
  } else if (port::Lowercase(platform_->Name()) == "opencl") {
    platform_kind_ = PlatformKind::kOpenCL;
  } else if (port::Lowercase(platform_->Name()) == "host") {
    platform_kind_ = PlatformKind::kHost;
  }
}
StreamExecutor::~StreamExecutor() {
  BlockOnThreadExecutor(background_threads_.get());

  if (live_stream_count_.load() != 0) {
    LOG(WARNING) << "Not all streams were deallocated at executor destruction "
                 << "time. This may lead to unexpected/bad behavior - "
                 << "especially if any stream is still active!";
  }

  if (FLAGS_check_device_leaks) {
    for (auto it : mem_allocs_) {
      LOG(INFO) << "Memory alloced at executor exit: addr: "
                << port::Printf("%p", it.first)
                << ", bytes: " << it.second.bytes << ", trace: \n"
                << it.second.stack_trace;
    }
  }
}

port::Status StreamExecutor::Init(int device_ordinal,
                                  DeviceOptions device_options) {
  device_ordinal_ = device_ordinal;
  return implementation_->Init(device_ordinal, std::move(device_options));
}

port::Status StreamExecutor::Init() {
  return Init(0, DeviceOptions::Default());
}

bool StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
                               KernelBase *kernel) {
  return implementation_->GetKernel(spec, kernel);
}

void StreamExecutor::UnloadKernel(const KernelBase *kernel) {
  implementation_->UnloadKernel(kernel);
}

bool StreamExecutor::LoadModule(const MultiModuleLoaderSpec &spec,
                                ModuleHandle *module_handle) {
  return implementation_->LoadModule(spec, module_handle);
}

bool StreamExecutor::UnloadModule(ModuleHandle module_handle) {
  return implementation_->UnloadModule(module_handle);
}

void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
  VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
          << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();

  if (mem->opaque() != nullptr) {
    EraseAllocRecord(mem->opaque());
  }
  implementation_->Deallocate(mem);
  mem->Reset(nullptr, 0);
}

void StreamExecutor::GetMemAllocs(std::map<void *, AllocRecord> *records_out) {
  tf_shared_lock lock(mu_);
  *records_out = mem_allocs_;
}

bool StreamExecutor::CanEnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->CanEnablePeerAccessTo(other->implementation_.get());
}

port::Status StreamExecutor::EnablePeerAccessTo(StreamExecutor *other) {
  return implementation_->EnablePeerAccessTo(other->implementation_.get());
}

SharedMemoryConfig StreamExecutor::GetDeviceSharedMemoryConfig() {
  return implementation_->GetDeviceSharedMemoryConfig();
}

port::Status StreamExecutor::SetDeviceSharedMemoryConfig(
    SharedMemoryConfig config) {
  if (config != SharedMemoryConfig::kDefault &&
      config != SharedMemoryConfig::kFourByte &&
      config != SharedMemoryConfig::kEightByte) {
    string error_msg = port::Printf("Invalid shared memory config specified: %d",
                                    static_cast<int>(config));
    LOG(ERROR) << error_msg;
    return port::Status(port::error::INVALID_ARGUMENT, error_msg);
  }
  return implementation_->SetDeviceSharedMemoryConfig(config);
}

const DeviceDescription &StreamExecutor::GetDeviceDescription() const {
  mutex_lock lock(mu_);
  if (device_description_ != nullptr) {
    return *device_description_;
  }

  device_description_.reset(PopulateDeviceDescription());
  return *device_description_;
}

int64 StreamExecutor::GetDeviceLoad() const {
  return implementation_->GetDeviceLoad();
}

int StreamExecutor::PlatformDeviceCount() const {
  return implementation_->PlatformDeviceCount();
}

bool StreamExecutor::SupportsBlas() const {
  return implementation_->SupportsBlas();
}

bool StreamExecutor::SupportsRng() const {
  return implementation_->SupportsRng();
}

bool StreamExecutor::SupportsDnn() const {
  return implementation_->SupportsDnn();
}

bool StreamExecutor::GetConvolveAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveAlgorithms(with_winograd_nonfused, cc_major,
                                            cc_minor, out_algorithms);
}
bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetRnnAlgorithms(out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardDataAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardDataAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetConvolveBackwardFilterAlgorithms(
    bool with_winograd_nonfused,
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  int cc_major, cc_minor;
  GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor);
  return dnn_support->GetConvolveBackwardFilterAlgorithms(
      with_winograd_nonfused, cc_major, cc_minor, out_algorithms);
}

bool StreamExecutor::GetBlasGemmAlgorithms(
    std::vector<blas::AlgorithmType> *out_algorithms) {
  blas::BlasSupport *blas_support = AsBlas();
  if (!blas_support) {
    return false;
  }
  return blas_support->GetBlasGemmAlgorithms(out_algorithms);
}

port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
    int num_layers, int hidden_size, int input_size, int batch_size,
    dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
    dnn::RnnMode rnn_mode, dnn::DataType data_type,
    const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
    ScratchAllocator *state_allocator) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnDescriptor(
      num_layers, hidden_size, input_size, batch_size, input_mode,
      direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
      state_allocator);
}

port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
StreamExecutor::createRnnSequenceTensorDescriptor(int seq_length,
                                                  int batch_size, int data_size,
                                                  dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnSequenceTensorDescriptor(seq_length, batch_size,
                                                        data_size, data_type);
}

port::StatusOr<std::unique_ptr<dnn::RnnStateTensorDescriptor>>
StreamExecutor::createRnnStateTensorDescriptor(int num_layer, int batch_size,
                                               int data_size,
                                               dnn::DataType data_type) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return port::Status(port::error::UNKNOWN,
                        "Fail to find the dnn implementation.");
  }
  return dnn_support->createRnnStateTensorDescriptor(num_layer, batch_size,
                                                     data_size, data_type);
}

dnn::DnnSupport *StreamExecutor::AsDnn() {
  mutex_lock lock(mu_);
  if (dnn_ != nullptr) {
    return dnn_.get();
  }

  dnn_.reset(implementation_->CreateDnn());
  return dnn_.get();
}

blas::BlasSupport *StreamExecutor::AsBlas() {
  mutex_lock lock(mu_);
  if (blas_ != nullptr) {
    return blas_.get();
  }

  blas_.reset(implementation_->CreateBlas());
  return blas_.get();
}

fft::FftSupport *StreamExecutor::AsFft() {
  mutex_lock lock(mu_);
  if (fft_ != nullptr) {
    return fft_.get();
  }

  fft_.reset(implementation_->CreateFft());
  return fft_.get();
}

rng::RngSupport *StreamExecutor::AsRng() {
  mutex_lock lock(mu_);
  if (rng_ != nullptr) {
    return rng_.get();
  }

  rng_.reset(implementation_->CreateRng());
  return rng_.get();
}
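// Enqueues a kernel launch on the given stream, notifying registered
// TraceListeners via SubmitTrace before forwarding to the platform-specific
// implementation.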
bool StreamExecutor::Launch(Stream *stream, const ThreadDim &thread_dims,
                            const BlockDim &block_dims,
                            const KernelBase &kernel,
                            const KernelArgsArrayBase &args) {
  SubmitTrace(&TraceListener::LaunchSubmit, stream, thread_dims, block_dims,
              kernel, args);

  return implementation_->Launch(stream, thread_dims, block_dims, kernel,
                                 args);
}

port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
  port::Status result;
  SCOPED_TRACE(TraceListener::BlockHostUntilDone, &result, stream);

  result = implementation_->BlockHostUntilDone(stream);
  return result;
}

void *StreamExecutor::Allocate(uint64 size) {
  if (memory_limit_bytes_ > 0 &&
      mem_alloc_bytes_ + size > memory_limit_bytes_) {
    LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
                 << device_ordinal_
                 << " within provided limit. [used=" << mem_alloc_bytes_
                 << ", limit=" << memory_limit_bytes_ << "]";
    return nullptr;
  }
  void *buf = implementation_->Allocate(size);
  VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
          << buf << StackTraceIfVLOG10();
  CreateAllocRecord(buf, size);

  return buf;
}

port::StatusOr<DeviceMemoryBase> StreamExecutor::GetUntypedSymbol(
    const string &symbol_name, ModuleHandle module_handle) {
  // If failed to get the symbol, opaque/bytes are unchanged. Initialize them
  // to be nullptr/0 for consistency with DeviceMemory semantics.
  void *opaque = nullptr;
  size_t bytes = 0;
  if (GetSymbol(symbol_name, module_handle, &opaque, &bytes)) {
    return DeviceMemoryBase(opaque, bytes);
  }

  if (static_cast<bool>(module_handle)) {
    return port::Status(
        port::error::NOT_FOUND,
        port::StrCat("Check if module containing symbol ", symbol_name,
                     " is loaded (module_handle = ",
                     reinterpret_cast<uintptr_t>(module_handle.id()), ")"));
  } else {
    return port::Status(
        port::error::NOT_FOUND,
        port::StrCat("Check if kernel using the symbol is loaded: ",
                     symbol_name));
  }
}

bool StreamExecutor::GetSymbol(const string &symbol_name,
                               ModuleHandle module_handle, void **mem,
                               size_t *bytes) {
  return implementation_->GetSymbol(symbol_name, module_handle, mem, bytes);
}

void *StreamExecutor::UnifiedMemoryAllocate(uint64 bytes) {
  void *buffer = implementation_->UnifiedMemoryAllocate(bytes);
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryAllocate(size=" << bytes
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::UnifiedMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::UnifiedMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->UnifiedMemoryDeallocate(location);
}

void *StreamExecutor::HostMemoryAllocate(uint64 size) {
  void *buffer = implementation_->HostMemoryAllocate(size);
  VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
          << ") returns " << buffer << StackTraceIfVLOG10();
  return buffer;
}

void StreamExecutor::HostMemoryDeallocate(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location="
          << location << ")" << StackTraceIfVLOG10();

  return implementation_->HostMemoryDeallocate(location);
}

bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
  VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();
  if (location == nullptr || size == 0) {
    LOG(WARNING) << "attempting to register null or zero-sized memory: "
                 << location << "; size " << size;
  }
  return implementation_->HostMemoryRegister(location, size);
}

bool StreamExecutor::HostMemoryUnregister(void *location) {
  VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location="
          << location << ")" << StackTraceIfVLOG10();
  return implementation_->HostMemoryUnregister(location);
}
bool StreamExecutor::SynchronizeAllActivity() {
  VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
          << StackTraceIfVLOG10();
  bool ok = implementation_->SynchronizeAllActivity();

  // This should all be quick and infallible work, so we can perform the
  // synchronization even in the case of failure.
  BlockOnThreadExecutor(background_threads_.get());

  return ok;
}

bool StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
                                        uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
          << ", size=" << size << ")" << StackTraceIfVLOG10();

  return implementation_->SynchronousMemZero(location, size);
}

bool StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
          << ", value=" << value << ", size=" << size << ")"
          << StackTraceIfVLOG10();

  return implementation_->SynchronousMemSet(location, value, size);
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const void *host_src, uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", host_src=" << host_src
          << ", size=" << size << ") H2D" << StackTraceIfVLOG10();

  // Tracing overloaded methods is very difficult due to issues with type
  // inference on template args. Since use of these overloaded methods is
  // discouraged anyway, this isn't a huge deal.
  port::Status status =
      implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(void *host_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
          << ", device_src=" << device_src.opaque() << ", size=" << size
          << ") D2H" << StackTraceIfVLOG10();

  port::Status status =
      implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *device_dst,
                                       const DeviceMemoryBase &device_src,
                                       uint64 size) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(device_dst="
          << device_dst->opaque() << ", device_src=" << device_src.opaque()
          << ", size=" << size << ") D2D" << StackTraceIfVLOG10();

  port::Status status = implementation_->SynchronousMemcpyDeviceToDevice(
      device_dst, device_src, size);
  if (!status.ok()) {
    LOG(ERROR) << "synchronous memcpy: " << status;
  }
  return status.ok();
}

port::Status StreamExecutor::SynchronousMemcpyD2H(
    const DeviceMemoryBase &device_src, int64 size, void *host_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(device_src="
          << device_src.opaque() << ", size=" << size
          << ", host_dst=" << host_dst << ")" << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, &result, device_src, size,
               host_dst);

  result = implementation_->SynchronousMemcpy(host_dst, device_src, size);
  if (!result.ok()) {
    result = port::Status(port::error::INTERNAL,
                          port::Printf("failed to synchronously memcpy "
                                       "device-to-host: device %p to host %p "
                                       "size %lld: %s",
                                       device_src.opaque(), host_dst, size,
                                       result.ToString().c_str()));
  }

  return result;
}
port::Status StreamExecutor::SynchronousMemcpyH2D(
    const void *host_src, int64 size, DeviceMemoryBase *device_dst) {
  VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src="
          << host_src << ", size=" << size
          << ", device_dst=" << device_dst->opaque() << ")"
          << StackTraceIfVLOG10();

  port::Status result;
  SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, &result, host_src, size,
               device_dst);

  result = implementation_->SynchronousMemcpy(device_dst, host_src, size);
  if (!result.ok()) {
    result = port::Status(
        port::error::INTERNAL,
        port::Printf("failed to synchronously memcpy host-to-device: host "
                     "%p to device %p size %lld: %s",
                     host_src, device_dst->opaque(), size,
                     result.ToString().c_str()));
  }

  return result;
}

bool StreamExecutor::Memcpy(Stream *stream, void *host_dst,
                            const DeviceMemoryBase &device_src, uint64 size) {
  return implementation_->Memcpy(stream, host_dst, device_src, size);
}

bool StreamExecutor::Memcpy(Stream *stream, DeviceMemoryBase *device_dst,
                            const void *host_src, uint64 size) {
  return implementation_->Memcpy(stream, device_dst, host_src, size);
}

bool StreamExecutor::MemcpyDeviceToDevice(Stream *stream,
                                          DeviceMemoryBase *device_dst,
                                          const DeviceMemoryBase &device_src,
                                          uint64 size) {
  return implementation_->MemcpyDeviceToDevice(stream, device_dst, device_src,
                                               size);
}

bool StreamExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
                             uint64 size) {
  return implementation_->MemZero(stream, location, size);
}

bool StreamExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
                              uint32 pattern, uint64 size) {
  CHECK_EQ(0, size % 4)
      << "need 32-bit multiple size to fill with 32-bit pattern";
  return implementation_->Memset32(stream, location, pattern, size);
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<void()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

bool StreamExecutor::HostCallback(Stream *stream,
                                  std::function<port::Status()> callback) {
  return implementation_->HostCallback(stream, std::move(callback));
}

port::Status StreamExecutor::AllocateEvent(Event *event) {
  return implementation_->AllocateEvent(event);
}

port::Status StreamExecutor::DeallocateEvent(Event *event) {
  return implementation_->DeallocateEvent(event);
}

port::Status StreamExecutor::RecordEvent(Stream *stream, Event *event) {
  return implementation_->RecordEvent(stream, event);
}

port::Status StreamExecutor::WaitForEvent(Stream *stream, Event *event) {
  return implementation_->WaitForEvent(stream, event);
}

Event::Status StreamExecutor::PollForEventStatus(Event *event) {
  return implementation_->PollForEventStatus(event);
}

bool StreamExecutor::AllocateStream(Stream *stream) {
  live_stream_count_.fetch_add(1, std::memory_order_relaxed);
  if (!implementation_->AllocateStream(stream)) {
    auto count = live_stream_count_.fetch_sub(1);
    CHECK_GE(count, 0) << "live stream count should not dip below zero";
    LOG(INFO) << "failed to allocate stream; live stream count: " << count;
    return false;
  }

  return true;
}

void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
  return implementation_->CreateStreamDependency(dependent, other);
}

bool StreamExecutor::AllocateTimer(Timer *timer) {
  return implementation_->AllocateTimer(timer);
}

void StreamExecutor::DeallocateTimer(Timer *timer) {
  return implementation_->DeallocateTimer(timer);
}

bool StreamExecutor::StartTimer(Stream *stream, Timer *timer) {
  return implementation_->StartTimer(stream, timer);
}

bool StreamExecutor::StopTimer(Stream *stream, Timer *timer) {
  return implementation_->StopTimer(stream, timer);
}

DeviceDescription *StreamExecutor::PopulateDeviceDescription() const {
  return implementation_->PopulateDeviceDescription();
}

bool StreamExecutor::DeviceMemoryUsage(int64 *free, int64 *total) const {
  return implementation_->DeviceMemoryUsage(free, total);
}

void StreamExecutor::EnqueueOnBackgroundThread(std::function<void()> task) {
  background_threads_->Schedule(std::move(task));
}
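// Allocation bookkeeping for leak checking: CreateAllocRecord and
// EraseAllocRecord update mem_allocs_ and mem_alloc_bytes_, and are no-ops
// unless FLAGS_check_device_leaks is set.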
void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
  if (FLAGS_check_device_leaks && opaque != nullptr && bytes != 0) {
    mutex_lock lock(mu_);
    mem_allocs_[opaque] = AllocRecord{bytes, ""};
    mem_alloc_bytes_ += bytes;
  }
}

void StreamExecutor::EraseAllocRecord(void *opaque) {
  if (FLAGS_check_device_leaks && opaque != nullptr) {
    mutex_lock lock(mu_);
    if (mem_allocs_.find(opaque) == mem_allocs_.end()) {
      LOG(ERROR) << "Deallocating unknown pointer: "
                 << port::Printf("0x%p", opaque);
    } else {
      mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
      mem_allocs_.erase(opaque);
    }
  }
}

void StreamExecutor::EnableTracing(bool enabled) { tracing_enabled_ = enabled; }

void StreamExecutor::RegisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock(mu_);
    if (listeners_.find(listener) != listeners_.end()) {
      LOG(INFO) << "Attempt to register already-registered listener, "
                << listener;
    } else {
      listeners_.insert(listener);
    }
  }

  implementation_->RegisterTraceListener(listener);
}

bool StreamExecutor::UnregisterTraceListener(TraceListener *listener) {
  {
    mutex_lock lock(mu_);
    if (listeners_.find(listener) == listeners_.end()) {
      LOG(INFO) << "Attempt to unregister unknown listener, " << listener;
      return false;
    }
    listeners_.erase(listener);
  }

  implementation_->UnregisterTraceListener(listener);
  return true;
}

template <typename TraceCallT, typename... ArgsT>
void StreamExecutor::SubmitTrace(TraceCallT trace_call, ArgsT &&... args) {
  if (tracing_enabled_) {
    {
      // instance tracers held in a block to limit the lock lifetime.
      tf_shared_lock lock(mu_);
      for (TraceListener *listener : listeners_) {
        (listener->*trace_call)(std::forward<ArgsT>(args)...);
      }
    }
  }
}

internal::StreamExecutorInterface *StreamExecutor::implementation() {
  return implementation_->GetUnderlyingExecutor();
}

}  // namespace stream_executor