diff options
Diffstat (limited to 'tensorflow/core/grappler/optimizers/auto_parallel_test.cc')
-rw-r--r-- | tensorflow/core/grappler/optimizers/auto_parallel_test.cc | 42 |
1 files changed, 18 insertions, 24 deletions
diff --git a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc index 9a41b5e0b5..3d1b4a34bf 100644 --- a/tensorflow/core/grappler/optimizers/auto_parallel_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_parallel_test.cc @@ -33,7 +33,6 @@ TEST_F(AutoParallelTest, SimpleParallel) { Output constant_b = ops::Const(s.WithOpName("constant_b"), 1, {1}); Output var = ops::Variable(s.WithOpName("var"), {1}, DT_FLOAT); Output assign = ops::Assign(s.WithOpName("assign"), {var}, {constant_a}); - Output identity = ops::Identity(s.WithOpName("identity"), {var}); Output fifo_queue = ops::FIFOQueue(s.WithOpName("fifo_queue"), {DT_FLOAT}); auto dequeue = ops::QueueDequeueMany(s.WithOpName("dequeue"), {fifo_queue}, {constant_b}, {DT_FLOAT}); @@ -45,14 +44,13 @@ TEST_F(AutoParallelTest, SimpleParallel) { GrapplerItem item; item.init_ops.push_back("assign"); item.fetch.push_back("apply_gradient"); - item.init_ops.push_back("assign"); TF_CHECK_OK(s.ToGraphDef(&item.graph)); AutoParallel parallel(2); GraphDef output; Status status = parallel.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(21, output.node_size()); + EXPECT_EQ(20, output.node_size()); const NodeDef& node_assign = output.node(0); EXPECT_EQ("assign", node_assign.name()); @@ -64,64 +62,60 @@ TEST_F(AutoParallelTest, SimpleParallel) { const NodeDef& node_fifo_queue = output.node(2); EXPECT_EQ("fifo_queue", node_fifo_queue.name()); - const NodeDef& node_identity = output.node(3); - EXPECT_EQ("identity", node_identity.name()); - EXPECT_EQ("var", node_identity.input(0)); - - const NodeDef& node_var = output.node(4); + const NodeDef& node_var = output.node(3); EXPECT_EQ("var", node_var.name()); - const NodeDef& node_div_const0 = output.node(5); + const NodeDef& node_div_const0 = output.node(4); EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-Const", node_div_const0.name()); - const NodeDef& node_div0 = output.node(6); + const NodeDef& node_div0 = output.node(5); EXPECT_EQ("AutoParallel-Replica-0/AutoParallel-Div-apply_gradient", node_div0.name()); - const NodeDef& node_add0 = output.node(7); + const NodeDef& node_add0 = output.node(6); EXPECT_EQ("AutoParallel-Replica-0/add", node_add0.name()); - const NodeDef& node_gradient0 = output.node(8); + const NodeDef& node_gradient0 = output.node(7); EXPECT_EQ("AutoParallel-Replica-0/apply_gradient", node_gradient0.name()); - const NodeDef& node_constant_a0 = output.node(9); + const NodeDef& node_constant_a0 = output.node(8); EXPECT_EQ("AutoParallel-Replica-0/constant_a", node_constant_a0.name()); - const NodeDef& node_dequeue0 = output.node(10); + const NodeDef& node_dequeue0 = output.node(9); EXPECT_EQ("AutoParallel-Replica-0/dequeue", node_dequeue0.name()); - const NodeDef& node_learning_rate0 = output.node(11); + const NodeDef& node_learning_rate0 = output.node(10); EXPECT_EQ("AutoParallel-Replica-0/learning_rate", node_learning_rate0.name()); - const NodeDef& node_div_const1 = output.node(12); + const NodeDef& node_div_const1 = output.node(11); EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-Const", node_div_const1.name()); - const NodeDef& node_div1 = output.node(13); + const NodeDef& node_div1 = output.node(12); EXPECT_EQ("AutoParallel-Replica-1/AutoParallel-Div-apply_gradient", node_div1.name()); - const NodeDef& node_add1 = output.node(14); + const NodeDef& node_add1 = output.node(13); EXPECT_EQ("AutoParallel-Replica-1/add", node_add1.name()); - const NodeDef& node_gradient1 = output.node(15); + const NodeDef& node_gradient1 = output.node(14); EXPECT_EQ("AutoParallel-Replica-1/apply_gradient", node_gradient1.name()); - const NodeDef& node_constant_a1 = output.node(16); + const NodeDef& node_constant_a1 = output.node(15); EXPECT_EQ("AutoParallel-Replica-1/constant_a", node_constant_a1.name()); - const NodeDef& node_dequeue1 = output.node(17); + const NodeDef& node_dequeue1 = output.node(16); EXPECT_EQ("AutoParallel-Replica-1/dequeue", node_dequeue1.name()); - const NodeDef& node_learning_rate1 = output.node(18); + const NodeDef& node_learning_rate1 = output.node(17); EXPECT_EQ("AutoParallel-Replica-1/learning_rate", node_learning_rate1.name()); - const NodeDef& node_fetch = output.node(19); + const NodeDef& node_fetch = output.node(18); EXPECT_EQ("AutoParallel-Control-Fetch", node_fetch.name()); EXPECT_EQ("^AutoParallel-Replica-0/apply_gradient", node_fetch.input(0)); EXPECT_EQ("^AutoParallel-Replica-1/apply_gradient", node_fetch.input(1)); - const NodeDef& node_gradient = output.node(20); + const NodeDef& node_gradient = output.node(19); EXPECT_EQ("apply_gradient", node_gradient.name()); EXPECT_EQ("^AutoParallel-Control-Fetch", node_gradient.input(0)); } |