diff options
author | Anna R <annarev@google.com> | 2018-09-12 12:29:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-12 12:38:24 -0700 |
commit | 28e945e590b07de137f318a70896bc4fc31f7053 (patch) | |
tree | 59435a08e14420e284ac1c0aed60074bd2a71435 /tensorflow/stream_executor | |
parent | f337425dc71e3ea95aa91ce401a40c1b594486ca (diff) |
Internal change.
PiperOrigin-RevId: 212684548
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.cc | 24 | ||||
-rw-r--r-- | tensorflow/stream_executor/stream_executor_pimpl.h | 7 |
2 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index 9515d8e62a..10bf006787 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -22,6 +22,7 @@ limitations under the License. #include <atomic> #include <utility> +#include "tensorflow/core/util/env_var.h" #include "tensorflow/stream_executor/blas.h" #include "tensorflow/stream_executor/fft.h" #include "tensorflow/stream_executor/lib/env.h" @@ -163,6 +164,15 @@ StreamExecutor::StreamExecutor(PlatformKind platform_kind, CheckPlatformKindIsValid(platform_kind); } +// Get per-device memory limit in bytes. Returns 0 if +// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set. +static int64 GetMemoryLimitBytes() { + int64 value; + SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB", + 0, &value)); + return value * (1ll << 20); +} + StreamExecutor::StreamExecutor( const Platform *platform, std::unique_ptr<internal::StreamExecutorInterface> implementation) @@ -172,7 +182,9 @@ StreamExecutor::StreamExecutor( background_threads_(new port::ThreadPool( port::Env::Default(), "stream_executor", kNumBackgroundThreads)), live_stream_count_(0), - tracing_enabled_(false) { + tracing_enabled_(false), + mem_alloc_bytes_(0), + memory_limit_bytes_(GetMemoryLimitBytes()) { if (port::Lowercase(platform_->Name()) == "cuda") { platform_kind_ = PlatformKind::kCuda; } else if (port::Lowercase(platform_->Name()) == "opencl") { @@ -460,6 +472,14 @@ port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) { } void *StreamExecutor::Allocate(uint64 size) { + if (memory_limit_bytes_ > 0 && + mem_alloc_bytes_ + size > memory_limit_bytes_) { + LOG(WARNING) << "Not enough memory to allocate " << size << " on device " + << device_ordinal_ + << " within provided limit. [used=" << mem_alloc_bytes_ + << ", limit=" << memory_limit_bytes_ << "]"; + return nullptr; + } void *buf = implementation_->Allocate(size); VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns " << buf << StackTraceIfVLOG10(); @@ -779,6 +799,7 @@ void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) { mutex_lock lock(mu_); mem_allocs_[opaque] = AllocRecord{ bytes, ""}; + mem_alloc_bytes_ += bytes; } } @@ -789,6 +810,7 @@ void StreamExecutor::EraseAllocRecord(void *opaque) { LOG(ERROR) << "Deallocating unknown pointer: " << port::Printf("0x%p", opaque); } else { + mem_alloc_bytes_ -= mem_allocs_[opaque].bytes; mem_allocs_.erase(opaque); } } diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 437f298616..d04025b681 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -699,6 +699,13 @@ class StreamExecutor { // The set of TraceListeners registered for this StreamExecutor. std::set<TraceListener*> listeners_ GUARDED_BY(mu_); + // Allocated memory in bytes. + int64 mem_alloc_bytes_; + + // Memory limit in bytes. Value less or equal to 0 indicates there is no + // limit. + int64 memory_limit_bytes_; + SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor); }; |