aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/distributed_runtime
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-03 15:50:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-03 15:53:59 -0700
commit1737b824dad3af5681adb76585235ca96c2d57e2 (patch)
treec2235b91601d86b93d94aa1b52a79ff0a5f925a9 /tensorflow/core/distributed_runtime
parenta1810b42ea479ac1cee638a8507d52fc5dc61a3c (diff)
Add distributed model GetStepSequenceAsync implementation to
distributed_runtime/RpcCollectiveExecutorMgr. In a distributed environment, WorkerInterface is going to call this method at the group leader when fielding a GetStepSequence request from one of the other workers. PiperOrigin-RevId: 203196543
Diffstat (limited to 'tensorflow/core/distributed_runtime')
-rw-r--r--tensorflow/core/distributed_runtime/BUILD2
-rw-r--r--tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc26
-rw-r--r--tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc47
4 files changed, 81 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 75f8a19e9c..693b6b205f 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -494,9 +494,11 @@ tf_cc_test(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:session_options",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:worker_proto_cc",
],
)
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
index 5eeed6e382..45b989f6e2 100644
--- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.cc
@@ -99,6 +99,32 @@ void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
}
}
+void RpcCollectiveExecutorMgr::GetStepSequenceAsync(  // Fills `response` with one StepSequence per graph_key in `request`, then invokes `done`.
+ const GetStepSequenceRequest* request, GetStepSequenceResponse* response,
+ const StatusCallback& done) {
+ if (!group_leader_.empty()) {  // A non-empty group_leader_ is treated as "this process is not the group leader".
+ LOG(ERROR) << "GetStepSequence called at non-group-leader";
+ done(errors::Internal("GetStepSequenceAsync called at non-group-leader"));  // `response` is left untouched on this path.
+ } else {
+ mutex_lock l(sequence_mu_);  // Held for the rest of this branch, including the done() call below.
+ for (int64 graph_key : request->graph_key()) {
+ auto it = sequence_table_.find(graph_key);
+ GraphKeySequence* gks = nullptr;
+ if (it == sequence_table_.end()) {  // First request for this graph_key: create its sequence entry.
+ gks = new GraphKeySequence(graph_key);  // Raw pointer stored in sequence_table_; not freed here (presumably released elsewhere -- TODO confirm).
+ gks->next_step_id_ = NewRandomStepId();  // Seed the new sequence with a freshly generated step id.
+ sequence_table_[graph_key] = gks;
+ } else {
+ gks = it->second;  // Reuse the existing sequence for this graph_key.
+ }
+ StepSequence* ss = response->add_step_sequence();
+ ss->set_graph_key(graph_key);
+ ss->set_next_step_id(gks->next_step_id_);  // Reports the stored next_step_id_ without advancing it.
+ }
+ done(Status::OK());  // Invoked synchronously, while sequence_mu_ is still held.
+ }
+}
+
Status RpcCollectiveExecutorMgr::UpdateStepSequences(
const GetStepSequenceResponse& resp) {
mutex_lock l(sequence_mu_);
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h
index e9f3f0ebe8..c9581fa00f 100644
--- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h
@@ -42,6 +42,12 @@ class RpcCollectiveExecutorMgr : public CollectiveExecutorMgr {
virtual ~RpcCollectiveExecutorMgr();
+ // Fills `response` with a (graph_key, next_step_id) entry for each graph key in `request`, then invokes `done` with the resulting status.
+ // This function should only be called at the group_leader, by an RPC. Other needs for StepIds should be satisfied by NextStepId.
+ void GetStepSequenceAsync(const GetStepSequenceRequest* request,
+ GetStepSequenceResponse* response,
+ const StatusCallback& done) override;
+
void RefreshStepIdSequenceAsync(int64 graph_key,
const StatusCallback& done) override;
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
index 37b83d82be..0323300fdd 100644
--- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
@@ -121,4 +122,50 @@ TEST_F(RpcCollectiveExecutorMgrTest, NextStepId) {
EXPECT_GT(llabs(y - z), 3);
}
+TEST_F(RpcCollectiveExecutorMgrTest, GetStepSequence) {  // Checks that GetStepSequenceAsync yields valid, repeatable step ids per graph key.
+ int64 x = cme_->NextStepId(3);
+ EXPECT_EQ(x, CollectiveExecutor::kInvalidId);  // No sequence is expected to exist yet for graph key 3.
+ int64 y = cme_->NextStepId(4);
+ EXPECT_EQ(y, CollectiveExecutor::kInvalidId);  // Likewise for graph key 4.
+ GetStepSequenceRequest request;
+ GetStepSequenceResponse response;
+ request.add_graph_key(3);
+ request.add_graph_key(4);
+ {
+ Notification note;
+ Status status;
+ cme_->GetStepSequenceAsync(&request, &response,
+ [&status, &note](const Status& s) {  // Capture only what the callback uses.
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();  // Block until the callback has run.
+ EXPECT_TRUE(status.ok());
+ }
+ ASSERT_EQ(2, response.step_sequence_size());  // One entry per requested graph key.
+ std::unordered_map<int64, int64> values;  // graph_key -> next_step_id from the first response.
+ for (const auto& ss : response.step_sequence()) {
+ values[ss.graph_key()] = ss.next_step_id();
+ }
+ EXPECT_NE(values[3], CollectiveExecutor::kInvalidId);
+ EXPECT_NE(values[4], CollectiveExecutor::kInvalidId);
+ // Re-get, should be same values.
+ response.Clear();  // Drop the first round's entries so the size check below sees only new ones.
+ {
+ Notification note;
+ Status status;
+ cme_->GetStepSequenceAsync(&request, &response,
+ [&status, &note](const Status& s) {  // Capture only what the callback uses.
+ status = s;
+ note.Notify();
+ });
+ note.WaitForNotification();  // Block until the callback has run.
+ EXPECT_TRUE(status.ok());
+ }
+ ASSERT_EQ(2, response.step_sequence_size());
+ for (const auto& ss : response.step_sequence()) {
+ EXPECT_EQ(values[ss.graph_key()], ss.next_step_id());  // A repeated request must report the same ids.
+ }
+}
+
} // namespace tensorflow