aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/stream_executor/host
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-07-31 12:58:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-31 13:02:31 -0700
commit122750a879f508342b53ee2eaecfde1656981280 (patch)
tree71890cf506fa2d80c0b6a521e7249d7167403731 /tensorflow/stream_executor/host
parent7ebed6678c2f378a5e28c8bade99866718dbdc7e (diff)
[SE] Make ExecutorCache thread-safe, change ExecutorCache::Insert to ExecutorCache::GetOrCreate. Add support for creating Executors for different device ordinals in parallel.
[XLA] Create Executors in parallel. PiperOrigin-RevId: 163734988
Diffstat (limited to 'tensorflow/stream_executor/host')
-rw-r--r--tensorflow/stream_executor/host/host_platform.cc19
-rw-r--r--tensorflow/stream_executor/host/host_platform.h3
2 files changed, 2 insertions, 20 deletions
diff --git a/tensorflow/stream_executor/host/host_platform.cc b/tensorflow/stream_executor/host/host_platform.cc
index e93ccff4d8..2cb7d36967 100644
--- a/tensorflow/stream_executor/host/host_platform.cc
+++ b/tensorflow/stream_executor/host/host_platform.cc
@@ -63,23 +63,8 @@ port::StatusOr<StreamExecutor*> HostPlatform::ExecutorForDeviceWithPluginConfig(
port::StatusOr<StreamExecutor*> HostPlatform::GetExecutor(
const StreamExecutorConfig& config) {
- mutex_lock lock(executors_mutex_);
-
- port::StatusOr<StreamExecutor*> status = executor_cache_.Get(config);
- if (status.ok()) {
- return status.ValueOrDie();
- }
-
- port::StatusOr<std::unique_ptr<StreamExecutor>> executor =
- GetUncachedExecutor(config);
- if (!executor.ok()) {
- return executor.status();
- }
-
- StreamExecutor* naked_executor = executor.ValueOrDie().get();
- SE_RETURN_IF_ERROR(
- executor_cache_.Insert(config, executor.ConsumeValueOrDie()));
- return naked_executor;
+ return executor_cache_.GetOrCreate(
+ config, [&]() { return GetUncachedExecutor(config); });
}
port::StatusOr<std::unique_ptr<StreamExecutor>>
diff --git a/tensorflow/stream_executor/host/host_platform.h b/tensorflow/stream_executor/host/host_platform.h
index 86805ef3e3..0faec6c8b7 100644
--- a/tensorflow/stream_executor/host/host_platform.h
+++ b/tensorflow/stream_executor/host/host_platform.h
@@ -72,9 +72,6 @@ class HostPlatform : public Platform {
// This platform's name.
string name_;
- // mutex that guards the ordinal-to-executor map.
- mutable mutex executors_mutex_;
-
// Cache of created StreamExecutors.
ExecutorCache executor_cache_;