aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author Benoit Steiner <bsteiner@google.com> 2017-07-21 15:30:28 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-07-21 15:34:22 -0700
commit 10b6f290b71e684f85afc1790696d5032f29ed40 (patch)
tree 4b654154028efe90b66c231addb7ea40ad59e45a
parent 9513728ada1bb5aa571e477cb777027efb41b0fe (diff)
Properly schedule merge nodes.
PiperOrigin-RevId: 162792987
-rw-r--r-- tensorflow/core/grappler/costs/virtual_scheduler.cc | 6
-rw-r--r-- tensorflow/core/grappler/costs/virtual_scheduler_test.cc | 625
2 files changed, 532 insertions, 99 deletions
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 8b51bb9096..5c1d85d749 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -516,7 +516,11 @@ bool VirtualScheduler::MarkCurrNodeExecuted(const Costs& node_costs) {
for (auto* output_node : port_num_output_pair.second) {
auto& output_state = node_map_[output_node];
output_state.num_inputs_ready++;
- if (output_state.num_inputs_ready == output_state.inputs.size()) {
+ // Execute a node as soon as all its inputs are ready. Merge nodes are
+ // special since they run as soon as one of their inputs becomes
+ // available.
+ if (output_state.num_inputs_ready == output_state.inputs.size() ||
+ IsMerge(*output_node)) {
// This output node is now ready.
output_state.time_ready = curr_time;
ready_nodes_->AddNode(output_node);
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index 29e3db1b74..fa24b58504 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -83,22 +83,19 @@ class VirtualSchedulerTest : public ::testing::Test {
// Three Conv2Ds with only two in fetch nodes.
void CreateGrapplerItemWithConv2Ds() {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
- auto x = tensorflow::ops::RandomUniform(
+ Scope s = Scope::NewRootScope().WithDevice(kCPU0);
+ auto x = ops::RandomUniform(
s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
- auto y = tensorflow::ops::RandomUniform(
+ auto y = ops::RandomUniform(
s.WithOpName("y"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
- auto z = tensorflow::ops::RandomUniform(
+ auto z = ops::RandomUniform(
s.WithOpName("z"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
- auto f = tensorflow::ops::RandomUniform(
+ auto f = ops::RandomUniform(
s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
std::vector<int> strides = {1, 1, 1, 1};
- auto c0 =
- tensorflow::ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
- auto c1 =
- tensorflow::ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME");
- auto c2 =
- tensorflow::ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
+ auto c0 = ops::Conv2D(s.WithOpName("c0"), x, f, strides, "SAME");
+ auto c1 = ops::Conv2D(s.WithOpName("c1"), y, f, strides, "SAME");
+ auto c2 = ops::Conv2D(s.WithOpName("c2"), z, f, strides, "SAME");
GraphDef def;
TF_CHECK_OK(s.ToGraphDef(&def));
@@ -113,13 +110,13 @@ class VirtualSchedulerTest : public ::testing::Test {
// A Conv2D with a variable.
void CreateGrapplerItemWithConv2DAndVariable() {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
- auto x = tensorflow::ops::RandomUniform(
+ Scope s = Scope::NewRootScope().WithDevice(kCPU0);
+ auto x = ops::RandomUniform(
s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
- auto f = tensorflow::ops::Variable(
- s.WithOpName("f"), {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
+ auto f = ops::Variable(s.WithOpName("f"),
+ {kernel_, kernel_, depth_in_, depth_out_}, DT_FLOAT);
std::vector<int> strides = {1, 1, 1, 1};
- auto y = tensorflow::ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
+ auto y = ops::Conv2D(s.WithOpName("y"), x, f, strides, "SAME");
GraphDef def;
TF_CHECK_OK(s.ToGraphDef(&def));
@@ -132,25 +129,23 @@ class VirtualSchedulerTest : public ::testing::Test {
}
void CreateGrapplerItemWithMatmulChain() {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
+ Scope s = Scope::NewRootScope().WithDevice(kCPU0);
// Add control dependencies to ensure tests do not rely on specific
// manager and the order remains consistent for the test.
- auto a = tensorflow::ops::RandomUniform(s.WithOpName("a"), {3200, 3200},
- DT_FLOAT);
- auto b = tensorflow::ops::RandomUniform(
- s.WithOpName("b").WithControlDependencies(a), {3200, 3200}, DT_FLOAT);
- auto c = tensorflow::ops::RandomUniform(
- s.WithOpName("c").WithControlDependencies(b), {3200, 3200}, DT_FLOAT);
- auto d = tensorflow::ops::RandomUniform(
- s.WithOpName("d").WithControlDependencies(c), {3200, 3200}, DT_FLOAT);
- auto e = tensorflow::ops::RandomUniform(
- s.WithOpName("e").WithControlDependencies(d), {3200, 3200}, DT_FLOAT);
-
- auto ab = tensorflow::ops::MatMul(
- s.WithOpName("ab").WithControlDependencies(e), a, b);
- auto abc = tensorflow::ops::MatMul(s.WithOpName("abc"), ab, c);
- auto abcd = tensorflow::ops::MatMul(s.WithOpName("abcd"), abc, d);
- auto abcde = tensorflow::ops::MatMul(s.WithOpName("abcde"), abcd, e);
+ auto a = ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, DT_FLOAT);
+ auto b = ops::RandomUniform(s.WithOpName("b").WithControlDependencies(a),
+ {3200, 3200}, DT_FLOAT);
+ auto c = ops::RandomUniform(s.WithOpName("c").WithControlDependencies(b),
+ {3200, 3200}, DT_FLOAT);
+ auto d = ops::RandomUniform(s.WithOpName("d").WithControlDependencies(c),
+ {3200, 3200}, DT_FLOAT);
+ auto e = ops::RandomUniform(s.WithOpName("e").WithControlDependencies(d),
+ {3200, 3200}, DT_FLOAT);
+
+ auto ab = ops::MatMul(s.WithOpName("ab").WithControlDependencies(e), a, b);
+ auto abc = ops::MatMul(s.WithOpName("abc"), ab, c);
+ auto abcd = ops::MatMul(s.WithOpName("abcd"), abc, d);
+ auto abcde = ops::MatMul(s.WithOpName("abcde"), abcd, e);
GraphDef def;
TF_CHECK_OK(s.ToGraphDef(&def));
@@ -168,17 +163,13 @@ class VirtualSchedulerTest : public ::testing::Test {
// AddN that takes 4 tensors with 10x10x10x10.
void CreateGrapplerItemWithAddN() {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
- auto x = tensorflow::ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10},
- DT_FLOAT);
- auto y = tensorflow::ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10},
- DT_FLOAT);
- auto z = tensorflow::ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10},
- DT_FLOAT);
- auto w = tensorflow::ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10},
- DT_FLOAT);
- tensorflow::OutputList input_tensors = {x, y, z, w};
- auto out = tensorflow::ops::AddN(s.WithOpName("out"), input_tensors);
+ Scope s = Scope::NewRootScope().WithDevice(kCPU0);
+ auto x = ops::RandomUniform(s.WithOpName("x"), {10, 10, 10, 10}, DT_FLOAT);
+ auto y = ops::RandomUniform(s.WithOpName("y"), {10, 10, 10, 10}, DT_FLOAT);
+ auto z = ops::RandomUniform(s.WithOpName("z"), {10, 10, 10, 10}, DT_FLOAT);
+ auto w = ops::RandomUniform(s.WithOpName("w"), {10, 10, 10, 10}, DT_FLOAT);
+ OutputList input_tensors = {x, y, z, w};
+ auto out = ops::AddN(s.WithOpName("out"), input_tensors);
GraphDef def;
TF_CHECK_OK(s.ToGraphDef(&def));
@@ -192,15 +183,15 @@ class VirtualSchedulerTest : public ::testing::Test {
// NoOp that takes 7 NoOps as control dependency.
void CreateGrapplerItemWithControlDependency() {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
+ Scope s = Scope::NewRootScope().WithDevice(kCPU0);
std::vector<string> input_noop_names = {"x", "y", "z", "w", "u", "v", "t"};
- std::vector<tensorflow::Operation> input_tensors;
+ std::vector<Operation> input_tensors;
for (const auto& input : input_noop_names) {
- auto x = tensorflow::ops::NoOp(s.WithOpName(input));
+ auto x = ops::NoOp(s.WithOpName(input));
input_tensors.push_back(x.operation);
}
- auto out = tensorflow::ops::NoOp(
- s.WithControlDependencies(input_tensors).WithOpName("out"));
+ auto out =
+ ops::NoOp(s.WithControlDependencies(input_tensors).WithOpName("out"));
GraphDef def;
TF_CHECK_OK(s.ToGraphDef(&def));
@@ -215,33 +206,33 @@ class VirtualSchedulerTest : public ::testing::Test {
// FusedBN [an op with multiple outputs] with multiple consumers (including
// control dependency).
void CreateGrapplerItemWithBatchNorm() {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
- auto x = tensorflow::ops::RandomUniform(
+ Scope s = Scope::NewRootScope().WithDevice(kCPU0);
+ auto x = ops::RandomUniform(
s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
- auto scale = tensorflow::ops::RandomUniform(s.WithOpName("scale"),
- {depth_in_}, DT_FLOAT);
- auto offset = tensorflow::ops::RandomUniform(s.WithOpName("offset"),
- {depth_in_}, DT_FLOAT);
- auto mean =
- tensorflow::ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
- auto var =
- tensorflow::ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
-
- auto batch_norm = tensorflow::ops::FusedBatchNorm(
+ auto scale =
+ ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
+ auto offset =
+ ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
+ auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
+ auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
+
+ auto batch_norm = ops::FusedBatchNorm(
s.WithOpName("bn"), x, scale, offset, mean, var,
ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
auto y = batch_norm.y;
auto batch_mean = batch_norm.batch_mean;
auto batch_var = batch_norm.batch_variance;
- auto z1 = tensorflow::ops::Add(s.WithOpName("z1"), x, y);
- auto z2 = tensorflow::ops::Add(s.WithOpName("z2"), batch_var, batch_var);
- auto z3 = tensorflow::ops::Add(s.WithOpName("z3"), batch_var, batch_var);
- std::vector<tensorflow::Operation> input_tensors = {
- batch_mean.op(), z1.z.op(), z2.z.op(), z3.z.op(),
+ auto z1 = ops::Add(s.WithOpName("z1"), x, y);
+ auto z2 = ops::Add(s.WithOpName("z2"), batch_var, batch_var);
+ auto z3 = ops::Add(s.WithOpName("z3"), batch_var, batch_var);
+ std::vector<Operation> input_tensors = {
+ batch_mean.op(),
+ z1.z.op(),
+ z2.z.op(),
+ z3.z.op(),
};
- auto z4 = tensorflow::ops::NoOp(
- s.WithControlDependencies(batch_var).WithOpName("z4"));
+ auto z4 = ops::NoOp(s.WithControlDependencies(batch_var).WithOpName("z4"));
GraphDef def;
TF_CHECK_OK(s.ToGraphDef(&def));
@@ -261,21 +252,19 @@ class VirtualSchedulerTest : public ::testing::Test {
// Create a fused batchnorm on one device, and send the outputs to another
// device.
void CreateGrapplerItemWithInterDeviceTransfers() {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0);
+ Scope s = Scope::NewRootScope().WithDevice(kCPU0);
// Create a FusedBatchNorm op that has multiple output ports.
- auto x = tensorflow::ops::RandomUniform(
+ auto x = ops::RandomUniform(
s.WithOpName("x"), {batch_size_, width_, height_, depth_in_}, DT_FLOAT);
- auto scale = tensorflow::ops::RandomUniform(s.WithOpName("scale"),
- {depth_in_}, DT_FLOAT);
- auto offset = tensorflow::ops::RandomUniform(s.WithOpName("offset"),
- {depth_in_}, DT_FLOAT);
- auto mean =
- tensorflow::ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
- auto var =
- tensorflow::ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
-
- auto batch_norm = tensorflow::ops::FusedBatchNorm(
+ auto scale =
+ ops::RandomUniform(s.WithOpName("scale"), {depth_in_}, DT_FLOAT);
+ auto offset =
+ ops::RandomUniform(s.WithOpName("offset"), {depth_in_}, DT_FLOAT);
+ auto mean = ops::RandomUniform(s.WithOpName("mean"), {0}, DT_FLOAT);
+ auto var = ops::RandomUniform(s.WithOpName("var"), {0}, DT_FLOAT);
+
+ auto batch_norm = ops::FusedBatchNorm(
s.WithOpName("bn"), x, scale, offset, mean, var,
ops::FusedBatchNorm::IsTraining(true).Epsilon(0.1f));
auto y = batch_norm.y;
@@ -284,20 +273,18 @@ class VirtualSchedulerTest : public ::testing::Test {
// Copy FusedBatchNorm's outputs to CPU1.
// y1 and y2 take the same tensor, so there should be only 1 Send and Recv.
- auto y1 =
- tensorflow::ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
- auto y2 =
- tensorflow::ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
+ auto y1 = ops::Identity(s.WithOpName("y1").WithDevice(kCPU1), y);
+ auto y2 = ops::Identity(s.WithOpName("y2").WithDevice(kCPU1), y);
// batch_mean1 and batch_var1 take different output ports, so each will
// initiate Send/Recv.
- auto batch_mean1 = tensorflow::ops::Identity(
+ auto batch_mean1 = ops::Identity(
s.WithOpName("batch_mean1").WithDevice(kCPU1), batch_mean);
- auto batch_var1 = tensorflow::ops::Identity(
- s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
+ auto batch_var1 =
+ ops::Identity(s.WithOpName("batch_var1").WithDevice(kCPU1), batch_var);
// This is a control dependency.
- auto control_dep = tensorflow::ops::NoOp(s.WithOpName("control_dep")
- .WithControlDependencies(y)
- .WithDevice(kCPU1));
+ auto control_dep = ops::NoOp(s.WithOpName("control_dep")
+ .WithControlDependencies(y)
+ .WithDevice(kCPU1));
GraphDef def;
TF_CHECK_OK(s.ToGraphDef(&def));
@@ -316,6 +303,402 @@ class VirtualSchedulerTest : public ::testing::Test {
dependency_["control_dep"] = {"bn"};
}
+ // A simple while loop
+ void CreateGrapplerItemWithLoop() {
+ // Test graph produced in python using:
+ /*
+ with tf.Graph().as_default():
+ i0 = tf.constant(0)
+ m0 = tf.ones([2, 2])
+ c = lambda i, m: i < 10
+ b = lambda i, m: [i+1, tf.concat([m, m], axis=0)]
+ r = tf.while_loop(
+ c, b, loop_vars=[i0, m0],
+ shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])])
+ with open('/tmp/graph.pbtxt', 'w') as f:
+ f.write(str(tf.get_default_graph().as_graph_def()))
+ */
+ const string gdef_ascii = R"EOF(
+node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+}
+node {
+ name: "ones"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 2
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+}
+node {
+ name: "while/Enter"
+ op: "Enter"
+ input: "Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "frame_name"
+ value {
+ s: "while/while/"
+ }
+ }
+ attr {
+ key: "is_constant"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "parallel_iterations"
+ value {
+ i: 10
+ }
+ }
+}
+node {
+ name: "while/Enter_1"
+ op: "Enter"
+ input: "ones"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "frame_name"
+ value {
+ s: "while/while/"
+ }
+ }
+ attr {
+ key: "is_constant"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "parallel_iterations"
+ value {
+ i: 10
+ }
+ }
+}
+node {
+ name: "while/Merge"
+ op: "Merge"
+ input: "while/Enter"
+ input: "while/NextIteration"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "while/Merge_1"
+ op: "Merge"
+ input: "while/Enter_1"
+ input: "while/NextIteration_1"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "while/Less/y"
+ op: "Const"
+ input: "^while/Merge"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 10
+ }
+ }
+ }
+}
+node {
+ name: "while/Less"
+ op: "Less"
+ input: "while/Merge"
+ input: "while/Less/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "while/LoopCond"
+ op: "LoopCond"
+ input: "while/Less"
+}
+node {
+ name: "while/Switch"
+ op: "Switch"
+ input: "while/Merge"
+ input: "while/LoopCond"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@while/Merge"
+ }
+ }
+ }
+}
+node {
+ name: "while/Switch_1"
+ op: "Switch"
+ input: "while/Merge_1"
+ input: "while/LoopCond"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_class"
+ value {
+ list {
+ s: "loc:@while/Merge_1"
+ }
+ }
+ }
+}
+node {
+ name: "while/Identity"
+ op: "Identity"
+ input: "while/Switch:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "while/Identity_1"
+ op: "Identity"
+ input: "while/Switch_1:1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "while/add/y"
+ op: "Const"
+ input: "^while/Identity"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+}
+node {
+ name: "while/add"
+ op: "Add"
+ input: "while/Identity"
+ input: "while/add/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "while/concat/axis"
+ op: "Const"
+ input: "^while/Identity"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+}
+node {
+ name: "while/concat"
+ op: "ConcatV2"
+ input: "while/Identity_1"
+ input: "while/Identity_1"
+ input: "while/concat/axis"
+ attr {
+ key: "N"
+ value {
+ i: 2
+ }
+ }
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "while/NextIteration"
+ op: "NextIteration"
+ input: "while/add"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "while/NextIteration_1"
+ op: "NextIteration"
+ input: "while/concat"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+node {
+ name: "while/Exit"
+ op: "Exit"
+ input: "while/Switch"
+ attr {
+ key: "T"
+ value {
+ type: DT_INT32
+ }
+ }
+}
+node {
+ name: "while/Exit_1"
+ op: "Exit"
+ input: "while/Switch_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+versions {
+ producer: 21
+}
+ )EOF";
+
+ grappler_item_.reset(new GrapplerItem);
+ CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii,
+ &grappler_item_->graph));
+ grappler_item_->id = "test_graph";
+ grappler_item_->fetch = {"while/Exit", "while/Exit_1"};
+ }
+
// Call this after creating grappler_item_ and setting up dependency_.
void InitScheduler() {
scheduler_.reset(new TestVirtualScheduler(
@@ -329,9 +712,10 @@ class VirtualSchedulerTest : public ::testing::Test {
int64 exec_cost = 0;
if (info.op_info.op() == "MatMul") {
exec_cost = 2000000000;
- }
- if (info.op_info.op() == "RandomUniform") {
+ } else if (info.op_info.op() == "RandomUniform") {
exec_cost = 1000000000;
+ } else {
+ exec_cost = 1000;
}
c.execution_time = Costs::NanoSeconds(exec_cost);
return c;
@@ -613,10 +997,11 @@ TEST_F(VirtualSchedulerTest, SummaryCostTest) {
auto ops_executed = RunScheduler("");
Costs c = scheduler_->Summary();
- // RandomUniform - 5
- // Matmuls - 4 * 2 = 8
- // Total: 13
- EXPECT_EQ(13000000, c.execution_time.asMicroSeconds().count());
+ // RandomUniform - 5 * 1s = 5s
+ // Matmuls - 4 * 2s = 8s
+ // Misc - 5 * 1us = 5us
+ // Total: 13s + 5us = 13000005us
+ EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
}
// Like the above SummaryCostTest, but makes sure the stepstats timeline is
@@ -629,7 +1014,7 @@ TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) {
RunMetadata metadata;
Costs c = scheduler_->Summary(&metadata);
StepStats stepstats = metadata.step_stats();
- EXPECT_EQ(13000000, c.execution_time.asMicroSeconds().count());
+ EXPECT_EQ(13000005, c.execution_time.asMicroSeconds().count());
// Should only be 1 device!
EXPECT_EQ(1, stepstats.dev_stats().size());
@@ -661,7 +1046,7 @@ TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) {
}
// The base start_time is the time to compute the RandomUniforms (5s) plus
// the 5us spent on the miscellaneous ops.
- int64 cur_time = static_cast<int64>(5000000);
+ int64 cur_time = static_cast<int64>(5000005);
// The increment is the execution time of one matmul. See
// CreateGrapplerItemWithMatmulChain for details.
int64 increment = static_cast<int64>(2000000);
@@ -935,5 +1320,49 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
EXPECT_EQ(get_output_size(recv_op_names[-1]), 4);
EXPECT_EQ(get_output_size(send_op_names[-1]), 4);
}
+
+TEST_F(VirtualSchedulerTest, WhileLoop) {
+ // Init.
+ CreateGrapplerItemWithLoop();
+ InitScheduler();
+
+ // Run the scheduler.
+ RunScheduler("");
+
+ // Check the timeline
+ RunMetadata metadata;
+ scheduler_->Summary(&metadata);
+
+ int num_next_iteration = 0;
+ int num_next_iteration_1 = 0;
+ int num_exit = 0;
+ int num_exit_1 = 0;
+ for (const auto& device_step_stats : metadata.step_stats().dev_stats()) {
+ for (const auto& stats : device_step_stats.node_stats()) {
+ std::cout << stats.DebugString() << std::endl;
+ if (stats.node_name() == "while/NextIteration") {
+ ++num_next_iteration;
+ EXPECT_EQ(19, stats.all_start_micros());
+ } else if (stats.node_name() == "while/NextIteration_1") {
+ ++num_next_iteration_1;
+ EXPECT_EQ(20, stats.all_start_micros());
+ } else if (stats.node_name() == "while/Exit") {
+ ++num_exit;
+ EXPECT_EQ(14, stats.all_start_micros());
+ } else if (stats.node_name() == "while/Exit_1") {
+ ++num_exit_1;
+ EXPECT_EQ(12, stats.all_start_micros());
+ }
+ }
+ }
+
+ // Make sure we went through the body of the loop once, and that the output of
+ // the loop was scheduled as well.
+ EXPECT_EQ(1, num_next_iteration);
+ EXPECT_EQ(1, num_next_iteration_1);
+ EXPECT_EQ(1, num_exit);
+ EXPECT_EQ(1, num_exit_1);
+}
+
} // end namespace grappler
} // end namespace tensorflow