diff options
Diffstat (limited to 'tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc')
-rw-r--r-- | tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc index 37b83d82be..0323300fdd 100644 --- a/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc +++ b/tensorflow/core/distributed_runtime/rpc_collective_executor_mgr_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/worker.pb.h" #include "tensorflow/core/public/session_options.h" namespace tensorflow { @@ -121,4 +122,50 @@ TEST_F(RpcCollectiveExecutorMgrTest, NextStepId) { EXPECT_GT(llabs(y - z), 3); } +TEST_F(RpcCollectiveExecutorMgrTest, GetStepSequence) { + int64 x = cme_->NextStepId(3); + EXPECT_EQ(x, CollectiveExecutor::kInvalidId); + int64 y = cme_->NextStepId(4); + EXPECT_EQ(y, CollectiveExecutor::kInvalidId); + GetStepSequenceRequest request; + GetStepSequenceResponse response; + request.add_graph_key(3); + request.add_graph_key(4); + { + Notification note; + Status status; + cme_->GetStepSequenceAsync(&request, &response, + [this, &status, ¬e](const Status& s) { + status = s; + note.Notify(); + }); + note.WaitForNotification(); + EXPECT_TRUE(status.ok()); + } + ASSERT_EQ(2, response.step_sequence_size()); + std::unordered_map<int64, int64> values; + for (const auto& ss : response.step_sequence()) { + values[ss.graph_key()] = ss.next_step_id(); + } + EXPECT_NE(values[3], CollectiveExecutor::kInvalidId); + EXPECT_NE(values[4], CollectiveExecutor::kInvalidId); + // Re-get, should be same values. + response.Clear(); + { + Notification note; + Status status; + cme_->GetStepSequenceAsync(&request, &response, + [this, &status, ¬e](const Status& s) { + status = s; + note.Notify(); + }); + note.WaitForNotification(); + EXPECT_TRUE(status.ok()); + } + ASSERT_EQ(2, response.step_sequence_size()); + for (const auto& ss : response.step_sequence()) { + EXPECT_EQ(values[ss.graph_key()], ss.next_step_id()); + } +} + } // namespace tensorflow |