aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Ayush Dubey <ayushd@google.com>2018-10-03 13:25:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 13:36:30 -0700
commit261b6958fb95db18cd28c1aba140a627deb790a1 (patch)
treec8fb6ce5a685518edd61668dd764dc40fd3f5cf5 /tensorflow/core
parentc2c8cfe22492cf7fab804d32283b623632270035 (diff)
Enable collective graph key test for GPU builds.
In the process, properly place nodes on devices in the collective graph key test. PiperOrigin-RevId: 215616146
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc58
1 files changed, 26 insertions, 32 deletions
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index e3e431f800..a6440c55ad 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -2262,8 +2262,8 @@ class DirectSessionCollectiveTest : public ::testing::Test {
TF_RETURN_IF_ERROR(session->Create(g));
std::vector<Tensor> outputs;
TF_RETURN_IF_ERROR(
- session->Run({{"input1:0", t1}, {"input2:0", t2}}, {},
- {"collective_call1:0", "collective_call2:0"}, &outputs));
+ session->Run({{"input0:0", t1}, {"input1:0", t2}}, {},
+ {"collective_call0:0", "collective_call1:0"}, &outputs));
DirectSession* direct_session = static_cast<DirectSession*>(session.get());
{
mutex_lock l(direct_session->collective_graph_key_lock_);
@@ -2301,6 +2301,26 @@ class DirectSessionCollectiveTest : public ::testing::Test {
}});
}
+ NodeDef Input(int id) {
+ AttrValue dtype_attr;
+ SetAttrValue(DT_FLOAT, &dtype_attr);
+ NodeDef input;
+ input.set_name(strings::StrCat("input", id));
+ input.set_op("Placeholder");
+ input.mutable_attr()->insert({"dtype", dtype_attr});
+ return input;
+ }
+
+ NodeDef CollectiveCall(const string& op, const string& input, int cpu_id) {
+ NodeDef collective_call;
+ collective_call.set_name(strings::StrCat("collective_call", cpu_id));
+ collective_call.set_op(op);
+ collective_call.add_input(input);
+ collective_call.set_device(
+ strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", cpu_id));
+ return collective_call;
+ }
+
// Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and
// CPU1, with instance_key 1, and appropriate placeholder inputs. If
// `add_unused_function` is true, adds another CollectiveFunction with
@@ -2317,42 +2337,17 @@ class DirectSessionCollectiveTest : public ::testing::Test {
*lib->add_function() = unused_function;
}
- // Inputs.
- AttrValue dtype_attr;
- SetAttrValue(DT_FLOAT, &dtype_attr);
- NodeDef input1;
- input1.set_name("input1");
- input1.set_op("Placeholder");
- input1.mutable_attr()->insert({"dtype", dtype_attr});
- NodeDef input2;
- input2.set_name("input2");
- input2.set_op("Placeholder");
- input2.mutable_attr()->insert({"dtype", dtype_attr});
-
+ *g.add_node() = Input(0);
+ *g.add_node() = Input(1);
// CollectiveReduce on CPU0 with instance_key 1.
- NodeDef collective_call1;
- collective_call1.set_name("collective_call1");
- collective_call1.set_op("CollectiveFunction1");
- collective_call1.add_input("input1");
- collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:0");
+ *g.add_node() = CollectiveCall("CollectiveFunction1", "input0", 0);
// CollectiveReduce on CPU1 with instance_key 1.
- NodeDef collective_call2;
- collective_call2.set_name("collective_call2");
- collective_call2.set_op("CollectiveFunction1");
- collective_call2.add_input("input2");
- collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:1");
-
- *g.add_node() = input1;
- *g.add_node() = input2;
- *g.add_node() = collective_call1;
- *g.add_node() = collective_call2;
+ *g.add_node() = CollectiveCall("CollectiveFunction1", "input1", 1);
return g;
}
};
-#ifndef GOOGLE_CUDA
-// TODO(ayushd): enable this test for GPU builds.
TEST_F(DirectSessionCollectiveTest,
TestCollectiveGraphKeyUsesOnlyCalledFunctions) {
int64 key1;
@@ -2361,6 +2356,5 @@ TEST_F(DirectSessionCollectiveTest,
TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key2));
ASSERT_EQ(key1, key2);
}
-#endif
} // namespace tensorflow