diff options
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.cc')
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.cc | 24 |
1 files changed, 23 insertions, 1 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 9515d8e62a..10bf006787 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -22,6 +22,7 @@ limitations under the License. #include <atomic> #include <utility> +#include "tensorflow/core/util/env_var.h" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/fft.h" #include "tensorflow/stream_executor/lib/env.h" @@ -163,6 +164,15 @@ StreamExecutor::StreamExecutor(PlatformKind platform_kind, CheckPlatformKindIsValid(platform_kind); } +// Get per-device memory limit in bytes. Returns 0 if +// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set. +static int64 GetMemoryLimitBytes() { + int64 value; + SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB", + 0, &value)); + return value * (1ll << 20); +} + StreamExecutor::StreamExecutor( const Platform *platform, std::unique_ptr<internal::StreamExecutorInterface> implementation) @@ -172,7 +182,9 @@ StreamExecutor::StreamExecutor( background_threads_(new port::ThreadPool( port::Env::Default(), "stream_executor", kNumBackgroundThreads)), live_stream_count_(0), - tracing_enabled_(false) { + tracing_enabled_(false), + mem_alloc_bytes_(0), + memory_limit_bytes_(GetMemoryLimitBytes()) { if (port::Lowercase(platform_->Name()) == "cuda") { platform_kind_ = PlatformKind::kCuda; } else if (port::Lowercase(platform_->Name()) == "opencl") { @@ -460,6 +472,14 @@ port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) { } void *StreamExecutor::Allocate(uint64 size) { + if (memory_limit_bytes_ > 0 && + mem_alloc_bytes_ + size > memory_limit_bytes_) { + LOG(WARNING) << "Not enough memory to allocate " << size << " on device " + << device_ordinal_ + << " within provided limit. [used=" << mem_alloc_bytes_ + << ", limit=" << memory_limit_bytes_ << "]"; + return nullptr; + } void *buf = implementation_->Allocate(size); VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns " << buf << StackTraceIfVLOG10(); @@ -779,6 +799,7 @@ void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) { mutex_lock lock(mu_); mem_allocs_[opaque] = AllocRecord{ bytes, ""}; + mem_alloc_bytes_ += bytes; } } @@ -789,6 +810,7 @@ void StreamExecutor::EraseAllocRecord(void *opaque) { LOG(ERROR) << "Deallocating unknown pointer: " << port::Printf("0x%p", opaque); } else { + mem_alloc_bytes_ -= mem_allocs_[opaque].bytes; mem_allocs_.erase(opaque); } } |