aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/stream_executor_pimpl.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/stream_executor/stream_executor_pimpl.cc')
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc24
1 files changed, 23 insertions, 1 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 9515d8e62a..10bf006787 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <atomic>
#include <utility>
+#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
@@ -163,6 +164,15 @@ StreamExecutor::StreamExecutor(PlatformKind platform_kind,
CheckPlatformKindIsValid(platform_kind);
}
+// Get per-device memory limit in bytes. Returns 0 if
+// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
+static int64 GetMemoryLimitBytes() {
+ int64 value;
+ SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
+ 0, &value));
+ return value * (1ll << 20);
+}
+
StreamExecutor::StreamExecutor(
const Platform *platform,
std::unique_ptr<internal::StreamExecutorInterface> implementation)
@@ -172,7 +182,9 @@ StreamExecutor::StreamExecutor(
background_threads_(new port::ThreadPool(
port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
live_stream_count_(0),
- tracing_enabled_(false) {
+ tracing_enabled_(false),
+ mem_alloc_bytes_(0),
+ memory_limit_bytes_(GetMemoryLimitBytes()) {
if (port::Lowercase(platform_->Name()) == "cuda") {
platform_kind_ = PlatformKind::kCuda;
} else if (port::Lowercase(platform_->Name()) == "opencl") {
@@ -460,6 +472,14 @@ port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
}
void *StreamExecutor::Allocate(uint64 size) {
+ if (memory_limit_bytes_ > 0 &&
+ mem_alloc_bytes_ + size > memory_limit_bytes_) {
+ LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
+ << device_ordinal_
+ << " within provided limit. [used=" << mem_alloc_bytes_
+ << ", limit=" << memory_limit_bytes_ << "]";
+ return nullptr;
+ }
void *buf = implementation_->Allocate(size);
VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
<< buf << StackTraceIfVLOG10();
@@ -779,6 +799,7 @@ void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
mutex_lock lock(mu_);
mem_allocs_[opaque] = AllocRecord{
bytes, ""};
+ mem_alloc_bytes_ += bytes;
}
}
@@ -789,6 +810,7 @@ void StreamExecutor::EraseAllocRecord(void *opaque) {
LOG(ERROR) << "Deallocating unknown pointer: "
<< port::Printf("0x%p", opaque);
} else {
+ mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
mem_allocs_.erase(opaque);
}
}