aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/platform_util.cc
diff options
context:
space:
mode:
authorGravatar Peter Hawkins <phawkins@google.com>2017-07-31 12:58:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-31 13:02:31 -0700
commit122750a879f508342b53ee2eaecfde1656981280 (patch)
tree71890cf506fa2d80c0b6a521e7249d7167403731 /tensorflow/compiler/xla/service/platform_util.cc
parent7ebed6678c2f378a5e28c8bade99866718dbdc7e (diff)
[SE] Make ExecutorCache thread-safe, change ExecutorCache::Insert to ExecutorCache::GetOrCreate. Add support for creating Executors for different device ordinals in parallel.
[XLA] Create Executors in parallel. PiperOrigin-RevId: 163734988
Diffstat (limited to 'tensorflow/compiler/xla/service/platform_util.cc')
-rw-r--r--tensorflow/compiler/xla/service/platform_util.cc38
1 file changed, 25 insertions(+), 13 deletions(-)
diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc
index 116bd3f067..4f915a0c2e 100644
--- a/tensorflow/compiler/xla/service/platform_util.cc
+++ b/tensorflow/compiler/xla/service/platform_util.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -140,21 +141,32 @@ PlatformUtil::GetStreamExecutors(se::Platform* platform) {
device_count = 1;
}
std::vector<se::StreamExecutor*> stream_executors(device_count, nullptr);
- for (int i = 0; i < device_count; ++i) {
- se::StreamExecutorConfig config;
- config.ordinal = i;
- auto executor_status = platform->GetExecutor(config);
- if (executor_status.ok()) {
- se::StreamExecutor* executor = executor_status.ValueOrDie();
- if (IsDeviceSupported(executor)) {
- stream_executors[i] = executor;
- }
- } else {
- LOG(WARNING) << "unable to create StreamExecutor for " << platform->Name()
- << ":" << i << ": "
- << executor_status.status().error_message();
+ VLOG(1) << "Initializing devices";
+ {
+ tensorflow::thread::ThreadPool thread_pool(
+ tensorflow::Env::Default(), "device_initialization", device_count);
+ for (int i = 0; i < device_count; ++i) {
+ thread_pool.Schedule([platform, i, &stream_executors]() {
+ VLOG(1) << "Started device init " << i;
+ se::StreamExecutorConfig config;
+ config.ordinal = i;
+ auto executor_status = platform->GetExecutor(config);
+ if (executor_status.ok()) {
+ se::StreamExecutor* executor = executor_status.ValueOrDie();
+ if (IsDeviceSupported(executor)) {
+ stream_executors[i] = executor;
+ }
+ } else {
+ LOG(WARNING) << "unable to create StreamExecutor for "
+ << platform->Name() << ":" << i << ": "
+ << executor_status.status().error_message();
+ }
+ VLOG(1) << "Finished device init " << i;
+ });
}
+ // Block here in thread_pool destructor until all devices are initialized.
}
+ VLOG(1) << "Device initialization complete";
if (std::all_of(stream_executors.begin(), stream_executors.end(),
[](se::StreamExecutor* s) { return s == nullptr; })) {
return InternalError("no supported devices found for platform %s",