diff options
author | Peter Hawkins <phawkins@google.com> | 2017-07-31 12:58:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-07-31 13:02:31 -0700 |
commit | 122750a879f508342b53ee2eaecfde1656981280 (patch) | |
tree | 71890cf506fa2d80c0b6a521e7249d7167403731 /tensorflow/compiler/xla/service/platform_util.cc | |
parent | 7ebed6678c2f378a5e28c8bade99866718dbdc7e (diff) |
[SE] Make ExecutorCache thread-safe, change ExecutorCache::Insert to ExecutorCache::GetOrCreate. Add support for creating Executors for different device ordinals in parallel.
[XLA] Create Executors in parallel.
PiperOrigin-RevId: 163734988
Diffstat (limited to 'tensorflow/compiler/xla/service/platform_util.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/platform_util.cc | 38 |
1 files changed, 25 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 116bd3f067..4f915a0c2e 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -140,21 +141,32 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) { device_count = 1; } std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr); - for (int i = 0; i < device_count; ++i) { - se::StreamExecutorConfig config; - config.ordinal = i; - auto executor_status = platform->GetExecutor(config); - if (executor_status.ok()) { - se::StreamExecutor* executor = executor_status.ValueOrDie(); - if (IsDeviceSupported(executor)) { - stream_executors[i] = executor; - } - } else { - LOG(WARNING) << "unable to create StreamExecutor for " << platform->Name() - << ":" << i << ": " - << executor_status.status().error_message(); + VLOG(1) << "Initializing devices"; + { + tensorflow::thread::ThreadPool thread_pool( + tensorflow::Env::Default(), "device_initialization", device_count); + for (int i = 0; i < device_count; ++i) { + thread_pool.Schedule([platform, i, &stream_executors]() { + VLOG(1) << "Started device init " << i; + se::StreamExecutorConfig config; + config.ordinal = i; + auto executor_status = platform->GetExecutor(config); + if (executor_status.ok()) { + se::StreamExecutor* executor = executor_status.ValueOrDie(); + if (IsDeviceSupported(executor)) { + stream_executors[i] = executor; + } + } else { + LOG(WARNING) << "unable to create StreamExecutor for " + << platform->Name() << ":" << i << ": " + << executor_status.status().error_message(); + } + VLOG(1) << "Finished device init " << i; + }); } + // Block here in thread_pool destructor until all devices are initialized. } + VLOG(1) << "Device initialization complete"; if (std::all_of(stream_executors.begin(), stream_executors.end(), [](se::StreamExecutor* s) { return s == nullptr; })) { return InternalError("no supported devices found for platform %s", |