diff options
author | 2017-07-21 15:30:28 -0700 | |
---|---|---|
committer | 2017-07-21 15:34:22 -0700 | |
commit | 10b6f290b71e684f85afc1790696d5032f29ed40 (patch) | |
tree | 4b654154028efe90b66c231addb7ea40ad59e45a | |
parent | 9513728ada1bb5aa571e477cb777027efb41b0fe (diff) |
Properly schedule merge nodes.
PiperOrigin-RevId: 162792987
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/grappler/costs/virtual_scheduler_test.cc | 625 |
2 files changed, 532 insertions, 99 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 8b51bb9096..5c1d85d749 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -516,7 +516,11 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) { for (auto* output_node : port_num_output_pair.second) { auto& output_state = node_map_[output_node]; output_state.num_inputs_ready++; - if (output_state.num_inputs_ready == output_state.inputs.size()) { + // Execute a node as soon as all its inputs are ready. Merge nodes are + // special since they run as soon as one of their inputs becomes + // available. + if (output_state.num_inputs_ready == output_state.inputs.size() || + IsMerge(*output_node)) { // This output node is now ready. output_state.time_ready = curr_time; ready_nodes_->AddNode(output_node); diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index 29e3db1b74..fa24b58504 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -83,22 +83,19 @@ class VirtualSchedulerTest : public ::testing::Test { // Three Conv2Ds with only two in fetch nodes. void CreateGrapplerItemWithConv2Ds() { - tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); - auto x = tensorflow::ops::RandomUniform( + Scope s = Scope::NewRootScope().WithDevice(kCPU0); + auto x = ops::RandomUniform( s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); - auto y = tensorflow::ops::RandomUniform( + auto y = ops::RandomUniform( s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); - auto z = tensorflow::ops::RandomUniform( + auto z = ops::RandomUniform( s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); - auto f = tensorflow::ops::RandomUniform( + auto f = ops::RandomUniform( s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT); std::vector<int> strides = {1, 1, 1, 1}; - auto c0 = - tensorflow::ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME"); - auto c1 = - tensorflow::ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME"); - auto c2 = - tensorflow::ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME"); + auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME"); + auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME"); + auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME"); GraphDef def; TF_CHECK_OK(s.ToGraphDef(&def)); @@ -113,13 +110,13 @@ class VirtualSchedulerTest : public ::testing::Test { // A Conv2D with a variable. void CreateGrapplerItemWithConv2DAndVariable() { - tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); - auto x = tensorflow::ops::RandomUniform( + Scope s = Scope::NewRootScope().WithDevice(kCPU0); + auto x = ops::RandomUniform( s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); - auto f = tensorflow::ops::Variable( - s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT); + auto f = ops::Variable(s.WithOpName("f"), + {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT); std::vector<int> strides = {1, 1, 1, 1}; - auto y = tensorflow::ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME"); + auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME"); GraphDef def; TF_CHECK_OK(s.ToGraphDef(&def)); @@ -132,25 +129,23 @@ class VirtualSchedulerTest : public ::testing::Test { } void CreateGrapplerItemWithMatmulChain() { - tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); + Scope s = Scope::NewRootScope().WithDevice(kCPU0); // Add control dependencies to ensure tests do not rely on specific // manager and the order remains consistent for the test. - auto a = tensorflow::ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, - DT_FLOAT); - auto b = tensorflow::ops::RandomUniform( - s.WithOpName("b").WithControlDependencies(a), {3200, 3200}, DT_FLOAT); - auto c = tensorflow::ops::RandomUniform( - s.WithOpName("c").WithControlDependencies(b), {3200, 3200}, DT_FLOAT); - auto d = tensorflow::ops::RandomUniform( - s.WithOpName("d").WithControlDependencies(c), {3200, 3200}, DT_FLOAT); - auto e = tensorflow::ops::RandomUniform( - s.WithOpName("e").WithControlDependencies(d), {3200, 3200}, DT_FLOAT); - - auto ab = tensorflow::ops::MatMul( - s.WithOpName("ab").WithControlDependencies(e), a, b); - auto abc = tensorflow::ops::MatMul(s.WithOpName("abc"), ab, c); - auto abcd = tensorflow::ops::MatMul(s.WithOpName("abcd"), abc, d); - auto abcde = tensorflow::ops::MatMul(s.WithOpName("abcde"), abcd, e); + auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT); + auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a), + {3200, 3200}, DT_FLOAT); + auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b), + {3200, 3200}, DT_FLOAT); + auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c), + {3200, 3200}, DT_FLOAT); + auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d), + {3200, 3200}, DT_FLOAT); + + auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b); + auto abc = ops::MatMul(s.WithOpName("abc"), ab, c); + auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d); + auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e); GraphDef def; TF_CHECK_OK(s.ToGraphDef(&def)); @@ -168,17 +163,13 @@ class VirtualSchedulerTest : public ::testing::Test { // AddN that takes 4 tensors with 10x10x10x10. void CreateGrapplerItemWithAddN() { - tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); - auto x = tensorflow::ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, - DT_FLOAT); - auto y = tensorflow::ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, - DT_FLOAT); - auto z = tensorflow::ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, - DT_FLOAT); - auto w = tensorflow::ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, - DT_FLOAT); - tensorflow::OutputList input_tensors = {x, y, z, w}; - auto out = tensorflow::ops::AddN(s.WithOpName("out"), input_tensors); + Scope s = Scope::NewRootScope().WithDevice(kCPU0); + auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT); + auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT); + auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT); + auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT); + OutputList input_tensors = {x, y, z, w}; + auto out = ops::AddN(s.WithOpName("out"), input_tensors); GraphDef def; TF_CHECK_OK(s.ToGraphDef(&def)); @@ -192,15 +183,15 @@ class VirtualSchedulerTest : public ::testing::Test { // NoOp that takes 7 NoOps as control dependency. void CreateGrapplerItemWithControlDependency() { - tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); + Scope s = Scope::NewRootScope().WithDevice(kCPU0); std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"}; - std::vector<tensorflow::Operation> input_tensors; + std::vector<Operation> input_tensors; for (const auto& input : input_noop_names) { - auto x = tensorflow::ops::NoOp(s.WithOpName(input)); + auto x = ops::NoOp(s.WithOpName(input)); input_tensors.push_back(x.operation); } - auto out = tensorflow::ops::NoOp( - s.WithControlDependencies(input_tensors).WithOpName("out")); + auto out = + ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out")); GraphDef def; TF_CHECK_OK(s.ToGraphDef(&def)); @@ -215,33 +206,33 @@ class VirtualSchedulerTest : public ::testing::Test { // FusedBN [an op with multiple outputs] with multiple consumers (including // control dependency). void CreateGrapplerItemWithBatchNorm() { - tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); - auto x = tensorflow::ops::RandomUniform( + Scope s = Scope::NewRootScope().WithDevice(kCPU0); + auto x = ops::RandomUniform( s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); - auto scale = tensorflow::ops::RandomUniform(s.WithOpName("scale"), - {depth_in_}, DT_FLOAT); - auto offset = tensorflow::ops::RandomUniform(s.WithOpName("offset"), - {depth_in_}, DT_FLOAT); - auto mean = - tensorflow::ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT); - auto var = - tensorflow::ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT); - - auto batch_norm = tensorflow::ops::FusedBatchNorm( + auto scale = + ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT); + auto offset = + ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT); + auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT); + auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT); + + auto batch_norm = ops::FusedBatchNorm( s.WithOpName("bn"), x, scale, offset, mean, var, ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f)); auto y = batch_norm.y; auto batch_mean = batch_norm.batch_mean; auto batch_var = batch_norm.batch_variance; - auto z1 = tensorflow::ops::Add(s.WithOpName("z1"), x, y); - auto z2 = tensorflow::ops::Add(s.WithOpName("z2"), batch_var, batch_var); - auto z3 = tensorflow::ops::Add(s.WithOpName("z3"), batch_var, batch_var); - std::vector<tensorflow::Operation> input_tensors = { - batch_mean.op(), z1.z.op(), z2.z.op(), z3.z.op(), + auto z1 = ops::Add(s.WithOpName("z1"), x, y); + auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var); + auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var); + std::vector<Operation> input_tensors = { + batch_mean.op(), + z1.z.op(), + z2.z.op(), + z3.z.op(), }; - auto z4 = tensorflow::ops::NoOp( - s.WithControlDependencies(batch_var).WithOpName("z4")); + auto z4 = ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4")); GraphDef def; TF_CHECK_OK(s.ToGraphDef(&def)); @@ -261,21 +252,19 @@ class VirtualSchedulerTest : public ::testing::Test { // Create a fused bathcnorm on one device, and send the outputs to another // device. void CreateGrapplerItemWithInterDeviceTransfers() { - tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); + Scope s = Scope::NewRootScope().WithDevice(kCPU0); // Create a FusedBatchNorm op that has multiple output ports. - auto x = tensorflow::ops::RandomUniform( + auto x = ops::RandomUniform( s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT); - auto scale = tensorflow::ops::RandomUniform(s.WithOpName("scale"), - {depth_in_}, DT_FLOAT); - auto offset = tensorflow::ops::RandomUniform(s.WithOpName("offset"), - {depth_in_}, DT_FLOAT); - auto mean = - tensorflow::ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT); - auto var = - tensorflow::ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT); - - auto batch_norm = tensorflow::ops::FusedBatchNorm( + auto scale = + ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT); + auto offset = + ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT); + auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT); + auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT); + + auto batch_norm = ops::FusedBatchNorm( s.WithOpName("bn"), x, scale, offset, mean, var, ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f)); auto y = batch_norm.y; @@ -284,20 +273,18 @@ class VirtualSchedulerTest : public ::testing::Test { // Copy FusedBatchNorm's outputs to CPU1. // y1 and y2 take the same tensor, so there should be only 1 Send and Recv. - auto y1 = - tensorflow::ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y); - auto y2 = - tensorflow::ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y); + auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y); + auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y); // batch_mean1 and batch_var1 take different output ports, so each will // initiate Send/Recv. - auto batch_mean1 = tensorflow::ops::Identity( + auto batch_mean1 = ops::Identity( s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean); - auto batch_var1 = tensorflow::ops::Identity( - s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var); + auto batch_var1 = + ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var); // This is control dependency. - auto control_dep = tensorflow::ops::NoOp(s.WithOpName("control_dep") - .WithControlDependencies(y) - .WithDevice(kCPU1)); + auto control_dep = ops::NoOp(s.WithOpName("control_dep") + .WithControlDependencies(y) + .WithDevice(kCPU1)); GraphDef def; TF_CHECK_OK(s.ToGraphDef(&def)); @@ -316,6 +303,402 @@ class VirtualSchedulerTest : public ::testing::Test { dependency_["control_dep"] = {"bn"}; } + // A simple while loop + void CreateGrapplerItemWithLoop() { + // Test graph produced in python using: + /* + with tf.Graph().as_default(): + i0 = tf.constant(0) + m0 = tf.ones([2, 2]) + c = lambda i, m: i < 10 + b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] + r = tf.while_loop( + c, b, loop_vars=[i0, m0], + shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) + with open('/tmp/graph.pbtxt', 'w') as f: + f.write(str(tf.get_default_graph().as_graph_def())) + */ + const string gdef_ascii = R"EOF( +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "ones" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 2 + } + dim { + size: 2 + } + } + float_val: 1.0 + } + } + } +} +node { + name: "while/Enter" + op: "Enter" + input: "Const" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "frame_name" + value { + s: "while/while/" + } + } + attr { + key: "is_constant" + value { + b: false + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} +node { + name: "while/Enter_1" + op: "Enter" + input: "ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "frame_name" + value { + s: "while/while/" + } + } + attr { + key: "is_constant" + value { + b: false + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} +node { + name: "while/Merge" + op: "Merge" + input: "while/Enter" + input: "while/NextIteration" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/Merge_1" + op: "Merge" + input: "while/Enter_1" + input: "while/NextIteration_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/Less/y" + op: "Const" + input: "^while/Merge" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 10 + } + } + } +} +node { + name: "while/Less" + op: "Less" + input: "while/Merge" + input: "while/Less/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/LoopCond" + op: "LoopCond" + input: "while/Less" +} +node { + name: "while/Switch" + op: "Switch" + input: "while/Merge" + input: "while/LoopCond" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@while/Merge" + } + } + } +} +node { + name: "while/Switch_1" + op: "Switch" + input: "while/Merge_1" + input: "while/LoopCond" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@while/Merge_1" + } + } + } +} +node { + name: "while/Identity" + op: "Identity" + input: "while/Switch:1" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/Identity_1" + op: "Identity" + input: "while/Switch_1:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/add/y" + op: "Const" + input: "^while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "while/add" + op: "Add" + input: "while/Identity" + input: "while/add/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/concat/axis" + op: "Const" + input: "^while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "while/concat" + op: "ConcatV2" + input: "while/Identity_1" + input: "while/Identity_1" + input: "while/concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +node { + name: "while/NextIteration" + op: "NextIteration" + input: "while/add" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/NextIteration_1" + op: "NextIteration" + input: "while/concat" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/Exit" + op: "Exit" + input: "while/Switch" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/Exit_1" + op: "Exit" + input: "while/Switch_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +versions { + producer: 21 +} + )EOF"; + + grappler_item_.reset(new GrapplerItem); + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, + &grappler_item_->graph)); + grappler_item_->id = "test_graph"; + grappler_item_->fetch = {"while/Exit", "while/Exit_1"}; + } + // Call this after creating grappler_item_ and setting up dependency_. void InitScheduler() { scheduler_.reset(new TestVirtualScheduler( @@ -329,9 +712,10 @@ class VirtualSchedulerTest : public ::testing::Test { int64 exec_cost = 0; if (info.op_info.op() == "MatMul") { exec_cost = 2000000000; - } - if (info.op_info.op() == "RandomUniform") { + } else if (info.op_info.op() == "RandomUniform") { exec_cost = 1000000000; + } else { + exec_cost = 1000; } c.execution_time = Costs::NanoSeconds(exec_cost); return c; @@ -613,10 +997,11 @@ TEST_F(VirtualSchedulerTest, SummaryCostTest) { auto ops_executed = RunScheduler(""); Costs c = scheduler_->Summary(); - // RandomUniform - 5 - // Matmuls - 4 * 2 = 8 - // Total: 13 - EXPECT_EQ(13000000, c.execution_time.asMicroSeconds().count()); + // RandomUniform - 5 * 1s + // Matmuls - 4 * 2s = 8 + // Misc - 5 * 1us + // Total: 13000005 + EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count()); } // Like the above SummaryCostTest, but makes sure the stepstats timeline is @@ -629,7 +1014,7 @@ TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) { RunMetadata metadata; Costs c = scheduler_->Summary(&metadata); StepStats stepstats = metadata.step_stats(); - EXPECT_EQ(13000000, c.execution_time.asMicroSeconds().count()); + EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count()); // Should only be 1 device! EXPECT_EQ(1, stepstats.dev_stats().size()); @@ -661,7 +1046,7 @@ TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) { } // The base start_time is the time to compute RandomUniforms - int64 cur_time = static_cast<int64>(5000000); + int64 cur_time = static_cast<int64>(5000005); // The increment is the execution time of one matmul. See // CreateGrapplerItemWithMatmulChain for details. int64 increment = static_cast<int64>(2000000); @@ -935,5 +1320,49 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { EXPECT_EQ(get_output_size(recv_op_names[-1]), 4); EXPECT_EQ(get_output_size(send_op_names[-1]), 4); } + +TEST_F(VirtualSchedulerTest, WhileLoop) { + // Init. + CreateGrapplerItemWithLoop(); + InitScheduler(); + + // Run the scheduler. + RunScheduler(""); + + // Check the timeline + RunMetadata metadata; + scheduler_->Summary(&metadata); + + int num_next_iteration = 0; + int num_next_iteration_1 = 0; + int num_exit = 0; + int num_exit_1 = 0; + for (const auto& device_step_stats : metadata.step_stats().dev_stats()) { + for (const auto& stats : device_step_stats.node_stats()) { + std::cout << stats.DebugString() << std::endl; + if (stats.node_name() == "while/NextIteration") { + ++num_next_iteration; + EXPECT_EQ(19, stats.all_start_micros()); + } else if (stats.node_name() == "while/NextIteration_1") { + ++num_next_iteration_1; + EXPECT_EQ(20, stats.all_start_micros()); + } else if (stats.node_name() == "while/Exit") { + ++num_exit; + EXPECT_EQ(14, stats.all_start_micros()); + } else if (stats.node_name() == "while/Exit_1") { + ++num_exit_1; + EXPECT_EQ(12, stats.all_start_micros()); + } + } + } + + // Make sure we went though the body of the loop once, and that the output of + // the loop was scheduled as well. + EXPECT_EQ(1, num_next_iteration); + EXPECT_EQ(1, num_next_iteration_1); + EXPECT_EQ(1, num_exit); + EXPECT_EQ(1, num_exit_1); +} + } // end namespace grappler } // end namespace tensorflow |