diff options
| field | value |
|---|---|
| author | 2018-05-01 13:15:53 -0700 |
| committer | 2018-05-01 13:19:40 -0700 |
| commit | 9149558a639efe82baf1b5201feccf2411343a8a (patch) |
| tree | 1a6d3648dc5c2c59a00ca37c0f72c4eee81cc378 /tensorflow/core/distributed_runtime/device_resolver_distributed.h |
| parent | 1a50cd4ca8c4fe1c1a9ea14f219fd98be8704a7d (diff) |
Collective Ops Part 5
Distributed-mode implementations of DeviceResolverInterface
and ParamResolverInterface. Extend Worker interface with
new methods in support of these interfaces.
This change is part of a series of changes introducing infrastructure
for collective ops and initial implementations of reduction and broadcast.
PiperOrigin-RevId: 194984585
Diffstat (limited to 'tensorflow/core/distributed_runtime/device_resolver_distributed.h')
-rw-r--r-- | tensorflow/core/distributed_runtime/device_resolver_distributed.h | 67 |
1 file changed, 67 insertions(+), 0 deletions(-)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DEVICE_RESOLVER_DISTRIBUTED_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DEVICE_RESOLVER_DISTRIBUTED_H_

#include <string>
#include <vector>

#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/lib/gtl/flatmap.h"

namespace tensorflow {
class DeviceMgr;
class WorkerCacheInterface;

// Distributed-mode implementation of DeviceResolverInterface: resolves
// DeviceLocality for devices that may live on remote tasks, caching
// remotely-fetched DeviceAttributes in attr_table_.
// NOTE(review): `mutex`, `LOCKS_EXCLUDED` and `GUARDED_BY` presumably come
// from tensorflow/core/platform/mutex.h via a transitive include of
// collective.h — confirm include-what-you-use.
class DeviceResolverDistributed : public DeviceResolverInterface {
 public:
  // `dev_mgr` and `worker_cache` are borrowed, not owned; they must outlive
  // this object. `task_name` names the local task this resolver serves.
  DeviceResolverDistributed(const DeviceMgr* dev_mgr,
                            WorkerCacheInterface* worker_cache,
                            const string& task_name);

  virtual ~DeviceResolverDistributed() {}

  // Populates *localities with one DeviceLocality per device referenced by
  // inst_params, invoking `done` on completion (see
  // GetDeviceLocalitiesRecursive below, which does the work).
  void GetDeviceLocalitiesAsync(const CollInstanceParams& inst_params,
                                std::vector<DeviceLocality>* localities,
                                const StatusCallback& done) override;

  // Resolves the locality of a single `device` belonging to `task`, writing
  // the result to *locality and invoking `done` when finished.
  void GetLocalityAsync(const string& device, const string& task,
                        DeviceLocality* locality,
                        const StatusCallback& done) override;

  // Drops state held for `task` — presumably cached attr_table_ entries, so a
  // restarted task is re-queried; confirm against the .cc implementation.
  void ClearTask(const string& task) override;

 protected:
  // Loads attr_table_ with device attributes retrieved from remote task.
  void RefreshRemoteAttributes(const string& device, const string& task,
                               const StatusCallback& done) LOCKS_EXCLUDED(mu_);

  // Subroutine used by GetDeviceLocalitiesAsync. Recursively extends
  // *localities with DeviceLocality of the corresponding device named
  // by inst_params.instance.device_names.
  void GetDeviceLocalitiesRecursive(const CollInstanceParams& inst_params,
                                    std::vector<DeviceLocality>* localities,
                                    const StatusCallback& done);

  const DeviceMgr* dev_mgr_;            // Not owned
  WorkerCacheInterface* worker_cache_;  // Not owned
  const string task_name_;              // Name of the local task.
  mutex mu_;
  // Cache of device attributes keyed by device name; guarded by mu_.
  gtl::FlatMap<string, DeviceAttributes> attr_table_ GUARDED_BY(mu_);
};

}  // namespace tensorflow
#endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DEVICE_RESOLVER_DISTRIBUTED_H_