diff options
-rw-r--r-- | tensorflow/core/common_runtime/direct_session_test.cc | 58 |
1 files changed, 26 insertions, 32 deletions
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index e3e431f800..a6440c55ad 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -2262,8 +2262,8 @@ class DirectSessionCollectiveTest : public ::testing::Test { TF_RETURN_IF_ERROR(session->Create(g)); std::vector<Tensor> outputs; TF_RETURN_IF_ERROR( - session->Run({{"input1:0", t1}, {"input2:0", t2}}, {}, - {"collective_call1:0", "collective_call2:0"}, &outputs)); + session->Run({{"input0:0", t1}, {"input1:0", t2}}, {}, + {"collective_call0:0", "collective_call1:0"}, &outputs)); DirectSession* direct_session = static_cast<DirectSession*>(session.get()); { mutex_lock l(direct_session->collective_graph_key_lock_); @@ -2301,6 +2301,26 @@ class DirectSessionCollectiveTest : public ::testing::Test { }}); } + NodeDef Input(int id) { + AttrValue dtype_attr; + SetAttrValue(DT_FLOAT, &dtype_attr); + NodeDef input; + input.set_name(strings::StrCat("input", id)); + input.set_op("Placeholder"); + input.mutable_attr()->insert({"dtype", dtype_attr}); + return input; + } + + NodeDef CollectiveCall(const string& op, const string& input, int cpu_id) { + NodeDef collective_call; + collective_call.set_name(strings::StrCat("collective_call", cpu_id)); + collective_call.set_op(op); + collective_call.add_input(input); + collective_call.set_device( + strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", cpu_id)); + return collective_call; + } + // Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and // CPU1, with instance_key 1, and appropriate placeholder inputs. If // `add_unused_function` is true, adds another CollectiveFunction with @@ -2317,42 +2337,17 @@ class DirectSessionCollectiveTest : public ::testing::Test { *lib->add_function() = unused_function; } - // Inputs. - AttrValue dtype_attr; - SetAttrValue(DT_FLOAT, &dtype_attr); - NodeDef input1; - input1.set_name("input1"); - input1.set_op("Placeholder"); - input1.mutable_attr()->insert({"dtype", dtype_attr}); - NodeDef input2; - input2.set_name("input2"); - input2.set_op("Placeholder"); - input2.mutable_attr()->insert({"dtype", dtype_attr}); - + *g.add_node() = Input(0); + *g.add_node() = Input(1); // CollectiveReduce on CPU0 with instance_key 1. - NodeDef collective_call1; - collective_call1.set_name("collective_call1"); - collective_call1.set_op("CollectiveFunction1"); - collective_call1.add_input("input1"); - collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:0"); + *g.add_node() = CollectiveCall("CollectiveFunction1", "input0", 0); // CollectiveReduce on CPU1 with instance_key 1. - NodeDef collective_call2; - collective_call2.set_name("collective_call2"); - collective_call2.set_op("CollectiveFunction1"); - collective_call2.add_input("input2"); - collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:1"); - - *g.add_node() = input1; - *g.add_node() = input2; - *g.add_node() = collective_call1; - *g.add_node() = collective_call2; + *g.add_node() = CollectiveCall("CollectiveFunction1", "input1", 1); return g; } }; -#ifndef GOOGLE_CUDA -// TODO(ayushd): enable this test for GPU builds. TEST_F(DirectSessionCollectiveTest, TestCollectiveGraphKeyUsesOnlyCalledFunctions) { int64 key1; @@ -2361,6 +2356,5 @@ TEST_F(DirectSessionCollectiveTest, TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key2)); ASSERT_EQ(key1, key2); } -#endif } // namespace tensorflow |