aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime/test_utils.h
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-08-21 08:18:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 08:25:13 -0700
commit5b456c9ab567c7d9262c57f32693ff33a87946e6 (patch)
treebb2b74a5b11ff04e9b3319ca017e53b1b48f5c40 /tensorflow/core/distributed_runtime/test_utils.h
parentaeab291563b0b4cc75c0f5fc73610a6595780570 (diff)
[Distributed] Add methods to WorkerCache that selectively list workers by job name.
PiperOrigin-RevId: 209597829
Diffstat (limited to 'tensorflow/core/distributed_runtime/test_utils.h')
-rw-r--r--tensorflow/core/distributed_runtime/test_utils.h14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h
index 48d83845dd..88a97da34d 100644
--- a/tensorflow/core/distributed_runtime/test_utils.h
+++ b/tensorflow/core/distributed_runtime/test_utils.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -138,6 +139,19 @@ class TestWorkerCache : public WorkerCacheInterface {
}
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {
+ workers->clear();
+ for (auto it : workers_) {
+ DeviceNameUtils::ParsedName device_name;
+ CHECK(DeviceNameUtils::ParseFullName(it.first, &device_name));
+ CHECK(device_name.has_job);
+ if (job_name == device_name.job) {
+ workers->push_back(it.first);
+ }
+ }
+ }
+
WorkerInterface* CreateWorker(const string& target) override {
auto it = workers_.find(target);
if (it != workers_.end()) {