diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-03 15:50:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-03 15:53:59 -0700 |
commit | 1737b824dad3af5681adb76585235ca96c2d57e2 (patch) | |
tree | c2235b91601d86b93d94aa1b52a79ff0a5f925a9 /tensorflow/core/distributed_runtime | |
parent | a1810b42ea479ac1cee638a8507d52fc5dc61a3c (diff) |
Add distributed model GetStepSequenceAsync implementation to
distributed_runtiume/RpcCollectiveExecutorMgr.
In a distributed environment WorkerInterface is going to call this
method at the group leader when fielding a GetStepSequence request
from one of the other workers.
PiperOrigin-RevId: 203196543
Diffstat (limited to 'tensorflow/core/distributed_runtime')
4 files changed, 81 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index 75f8a19e9c..693b6b205f 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -494,9 +494,11 @@ tf_cc_test( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", "//tensorflow/core:session_options", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:worker_proto_cc", ], ) diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc index 5eeed6e382..45b989f6e2 100644 --- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc @@ -99,6 +99,32 @@ void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync( } } +void RpcCollectiveExecutorMgr::GetStepSequenceAsync( + const GetStepSequenceRequest* request, GetStepSequenceResponse* response, + const StatusCallback& done) { + if (!group_leader_.empty()) { + LOG(ERROR) << "GetStepSequence called at non-group-leader"; + done(errors::Internal("GetStepSequenceAsync called at non-group-leader")); + } else { + mutex_lock l(sequence_mu_); + for (int64 graph_key : request->graph_key()) { + auto it = sequence_table_.find(graph_key); + GraphKeySequence* gks = nullptr; + if (it == sequence_table_.end()) { + gks = new GraphKeySequence(graph_key); + gks->next_step_id_ = NewRandomStepId(); + sequence_table_[graph_key] = gks; + } else { + gks = it->second; + } + StepSequence* ss = response->add_step_sequence(); + ss->set_graph_key(graph_key); + ss->set_next_step_id(gks->next_step_id_); + } + done(Status::OK()); + } +} + Status RpcCollectiveExecutorMgr::UpdateStepSequences( const GetStepSequenceResponse& resp) { mutex_lock l(sequence_mu_); diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h index e9f3f0ebe8..c9581fa00f 100644 --- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h @@ -42,6 +42,12 @@ class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr { virtual ~RpcCollectiveExecutorMgr(); + // This function should only be called at the group_leader, by an RPC. + // Other needs for StepIds should be satisfied by NextStepId. + void GetStepSequenceAsync(const GetStepSequenceRequest* request, + GetStepSequenceResponse* response, + const StatusCallback& done) override; + void RefreshStepIdSequenceAsync(int64 graph_key, const StatusCallback& done) override; diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc index 37b83d82be..0323300fdd 100644 --- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/worker.pb.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { @@ -121,4 +122,50 @@ TEST_F(RpcCollectiveExecutorMgrTest, NextStepId) { EXPECT_GT(llabs(y - z), 3); } +TEST_F(RpcCollectiveExecutorMgrTest, GetStepSequence) { + int64 x = cme_->NextStepId(3); + EXPECT_EQ(x, CollectiveExecutor::kInvalidId); + int64 y = cme_->NextStepId(4); + EXPECT_EQ(y, CollectiveExecutor::kInvalidId); + GetStepSequenceRequest request; + GetStepSequenceResponse response; + request.add_graph_key(3); + request.add_graph_key(4); + { + Notification note; + Status status; + cme_->GetStepSequenceAsync(&request, &response, + [this, &status, ¬e](const Status& s) { + status = s; + note.Notify(); + }); + note.WaitForNotification(); + EXPECT_TRUE(status.ok()); + } + ASSERT_EQ(2, response.step_sequence_size()); + std::unordered_map<int64, int64> values; + for (const auto& ss : response.step_sequence()) { + values[ss.graph_key()] = ss.next_step_id(); + } + EXPECT_NE(values[3], CollectiveExecutor::kInvalidId); + EXPECT_NE(values[4], CollectiveExecutor::kInvalidId); + // Re-get, should be same values. + response.Clear(); + { + Notification note; + Status status; + cme_->GetStepSequenceAsync(&request, &response, + [this, &status, ¬e](const Status& s) { + status = s; + note.Notify(); + }); + note.WaitForNotification(); + EXPECT_TRUE(status.ok()); + } + ASSERT_EQ(2, response.step_sequence_size()); + for (const auto& ss : response.step_sequence()) { + EXPECT_EQ(values[ss.graph_key()], ss.next_step_id()); + } +} + } // namespace tensorflow |