aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2018-09-12 12:29:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-12 12:38:24 -0700
commit28e945e590b07de137f318a70896bc4fc31f7053 (patch)
tree59435a08e14420e284ac1c0aed60074bd2a71435 /tensorflow/stream_executor
parentf337425dc71e3ea95aa91ce401a40c1b594486ca (diff)
Internal change.
PiperOrigin-RevId: 212684548
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc24
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h7
2 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 9515d8e62a..10bf006787 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <atomic>
#include <utility>
+#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
@@ -163,6 +164,15 @@ StreamExecutor::StreamExecutor(PlatformKind platform_kind,
CheckPlatformKindIsValid(platform_kind);
}
+// Get per-device memory limit in bytes. Returns 0 if
+// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
+static int64 GetMemoryLimitBytes() {
+ int64 value;
+ SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
+ 0, &value));
+ return value * (1ll << 20);
+}
+
StreamExecutor::StreamExecutor(
const Platform *platform,
std::unique_ptr<internal::StreamExecutorInterface> implementation)
@@ -172,7 +182,9 @@ StreamExecutor::StreamExecutor(
background_threads_(new port::ThreadPool(
port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
live_stream_count_(0),
- tracing_enabled_(false) {
+ tracing_enabled_(false),
+ mem_alloc_bytes_(0),
+ memory_limit_bytes_(GetMemoryLimitBytes()) {
if (port::Lowercase(platform_->Name()) == "cuda") {
platform_kind_ = PlatformKind::kCuda;
} else if (port::Lowercase(platform_->Name()) == "opencl") {
@@ -460,6 +472,14 @@ port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
}
void *StreamExecutor::Allocate(uint64 size) {
+ if (memory_limit_bytes_ > 0 &&
+ mem_alloc_bytes_ + size > memory_limit_bytes_) {
+ LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
+ << device_ordinal_
+ << " within provided limit. [used=" << mem_alloc_bytes_
+ << ", limit=" << memory_limit_bytes_ << "]";
+ return nullptr;
+ }
void *buf = implementation_->Allocate(size);
VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
<< buf << StackTraceIfVLOG10();
@@ -779,6 +799,7 @@ void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
mutex_lock lock(mu_);
mem_allocs_[opaque] = AllocRecord{
bytes, ""};
+ mem_alloc_bytes_ += bytes;
}
}
@@ -789,6 +810,7 @@ void StreamExecutor::EraseAllocRecord(void *opaque) {
LOG(ERROR) << "Deallocating unknown pointer: "
<< port::Printf("0x%p", opaque);
} else {
+ mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
mem_allocs_.erase(opaque);
}
}
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 437f298616..d04025b681 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -699,6 +699,13 @@ class StreamExecutor {
// The set of TraceListeners registered for this StreamExecutor.
std::set<TraceListener*> listeners_ GUARDED_BY(mu_);
+ // Allocated memory in bytes.
+ int64 mem_alloc_bytes_;
+
+ // Memory limit in bytes. Value less or equal to 0 indicates there is no
+ // limit.
+ int64 memory_limit_bytes_;
+
SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};