about summary refs log tree commit diff homepage
path: root/tensorflow/core/distributed_runtime/device_resolver_distributed.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-01 13:15:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-01 13:19:40 -0700
commit9149558a639efe82baf1b5201feccf2411343a8a (patch)
tree1a6d3648dc5c2c59a00ca37c0f72c4eee81cc378 /tensorflow/core/distributed_runtime/device_resolver_distributed.h
parent1a50cd4ca8c4fe1c1a9ea14f219fd98be8704a7d (diff)
Collective Ops Part 5
Distributed-mode implementations of DeviceResolverInterface and ParamResolverInterface. Extend Worker interface with new methods in support of these interfaces. This change is part of a series of changes introducing infrastructure for collective ops and initial implementations of reduction and broadcast. PiperOrigin-RevId: 194984585
Diffstat (limited to 'tensorflow/core/distributed_runtime/device_resolver_distributed.h')
-rw-r--r-- tensorflow/core/distributed_runtime/device_resolver_distributed.h | 67
1 file changed, 67 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/device_resolver_distributed.h b/tensorflow/core/distributed_runtime/device_resolver_distributed.h
new file mode 100644
index 0000000000..ac68ec6873
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/device_resolver_distributed.h
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DEVICE_RESOLVER_DISTRIBUTED_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DEVICE_RESOLVER_DISTRIBUTED_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+
+namespace tensorflow {
+class DeviceMgr;
+class WorkerCacheInterface;
+
+// Distributed-mode implementation of DeviceResolverInterface for collective
+// ops: resolves DeviceLocality for devices that may live on remote tasks,
+// caching fetched DeviceAttributes in attr_table_.
+class DeviceResolverDistributed : public DeviceResolverInterface {
+ public:
+  // `dev_mgr` and `worker_cache` are borrowed, not owned, and must outlive
+  // this object.  `task_name` identifies the local task.
+  DeviceResolverDistributed(const DeviceMgr* dev_mgr,
+                            WorkerCacheInterface* worker_cache,
+                            const string& task_name);
+
+  virtual ~DeviceResolverDistributed() {}
+
+  // Populates *localities with the DeviceLocality of each device named by
+  // inst_params.instance.device_names, then invokes `done`.  Delegates to
+  // GetDeviceLocalitiesRecursive (see below).
+  void GetDeviceLocalitiesAsync(const CollInstanceParams& inst_params,
+                                std::vector<DeviceLocality>* localities,
+                                const StatusCallback& done) override;
+
+  // Sets *locality for `device` on `task`, then invokes `done`.
+  // NOTE(review): presumably consults attr_table_ and falls back to
+  // RefreshRemoteAttributes on a cache miss — confirm in the .cc file.
+  void GetLocalityAsync(const string& device, const string& task,
+                        DeviceLocality* locality,
+                        const StatusCallback& done) override;
+
+  // Drops cached state associated with `task` (e.g. after a worker restart).
+  void ClearTask(const string& task) override;
+
+ protected:
+  // Loads attr_table_ with device attributes retrieved from remote task.
+  // Takes mu_ internally; callers must not hold it.
+  void RefreshRemoteAttributes(const string& device, const string& task,
+                               const StatusCallback& done) LOCKS_EXCLUDED(mu_);
+
+  // Subroutine used by GetDeviceLocalitiesAsync. Recursively extends
+  // *localities with DeviceLocality of the corresponding device named
+  // by inst_params.instance.device_names.
+  void GetDeviceLocalitiesRecursive(const CollInstanceParams& inst_params,
+                                    std::vector<DeviceLocality>* localities,
+                                    const StatusCallback& done);
+
+  const DeviceMgr* dev_mgr_;            // Not owned
+  WorkerCacheInterface* worker_cache_;  // Not owned
+  const string task_name_;              // Name of the local task
+  mutex mu_;                            // Guards attr_table_
+  // Cached DeviceAttributes, keyed by device name — TODO confirm key format
+  // (full device name vs. task-qualified) against the .cc implementation.
+  gtl::FlatMap<string, DeviceAttributes> attr_table_ GUARDED_BY(mu_);
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_DEVICE_RESOLVER_DISTRIBUTED_H_