aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/simple_placer_test.cc
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-02-24 09:59:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-24 15:33:10 -0800
commit9d84271a2039918994e57c2f962d2ee656f01541 (patch)
tree4d6507d1414ab9e644bfcbcec835c3c40eb03610 /tensorflow/core/common_runtime/simple_placer_test.cc
parent92383c8754179375ef1e91f270cd60a126cf77c4 (diff)
TensorFlow: Initial support in SimplePlacer for colocation groups,
to be used to colocate based on attributes rather than either names of ops or devices (op names and devices aren't portable). A follow up change will add an ops.colocate_with() to Python that adds this attribute to nodes, and will be used to replace calls to 'with tf.device(foo.device)' in TF library code, which assumes that devices have been specified. Change: 115463464
Diffstat (limited to 'tensorflow/core/common_runtime/simple_placer_test.cc')
-rw-r--r--tensorflow/core/common_runtime/simple_placer_test.cc91
1 files changed, 91 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc
index 14871c7a29..5dfbb60c2b 100644
--- a/tensorflow/core/common_runtime/simple_placer_test.cc
+++ b/tensorflow/core/common_runtime/simple_placer_test.cc
@@ -218,6 +218,13 @@ class SimplePlacerTest : public ::testing::Test {
GetNodeByName(g_, (name_b))->assigned_device_name()); \
} while (0)
+#define EXPECT_NOT_COLOCATED(g, name_a, name_b) \
+ do { \
+ Graph& g_ = (g); \
+ EXPECT_NE(GetNodeByName(g_, (name_a))->assigned_device_name(), \
+ GetNodeByName(g_, (name_b))->assigned_device_name()); \
+ } while (0)
+
#define EXPECT_DEVICE_TYPE(g, name, expected_device_type) \
EXPECT_EQ(DeviceType(expected_device_type).type(), \
devices_.FindDeviceByName( \
@@ -473,6 +480,90 @@ TEST_F(SimplePlacerTest, TestColocatedChainWithLongRangeColocations) {
}
}
+TEST_F(SimplePlacerTest, TestColocationGroup) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp(
+ "TestInput", b.opts().WithName("in").WithAttr("_class", "loc:ti"));
+ Node* colocated_with_input = ops::UnaryOp(
+ "TestRelu", input,
+ b.opts().WithName("colocated_1").WithAttr("_class", "loc:ti"));
+
+ // This will not be colocated with the input because TestInput is
+ // only availbale on CPU and TestRelu will default to GPU.
+ Node* not_colocated_with_input =
+ ops::UnaryOp("TestRelu", input, b.opts().WithName("foo"));
+ CHECK(colocated_with_input);
+ CHECK(not_colocated_with_input);
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ TF_EXPECT_OK(Place(&g));
+ EXPECT_COLOCATED(g, "in", "colocated_1");
+ EXPECT_NOT_COLOCATED(g, "in", "foo");
+}
+
+TEST_F(SimplePlacerTest, TestColocationGroupWithReferenceConnections) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1"));
+ Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2"));
+
+ // Two assigns (reference connections) with two different
+ // colocation groups. Because their colocation groups all map to the
+ // same device, this is a valid assignment.
+ ops::BinaryOp("TestAssign", var1, input,
+ b.opts().WithName("assign1").WithAttr("_class", "loc:1"));
+ ops::BinaryOp("TestAssign", var2, input,
+ b.opts().WithName("assign2").WithAttr("_class", "loc:2"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ TF_EXPECT_OK(Place(&g));
+ EXPECT_COLOCATED(g, "in", "var1");
+ EXPECT_COLOCATED(g, "in", "var2");
+ EXPECT_COLOCATED(g, "var1", "assign2");
+ EXPECT_COLOCATED(g, "var2", "assign1");
+}
+
+TEST_F(SimplePlacerTest,
+ TestColocationGroupWithUnsatisfiableReferenceConnections) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+
+ Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1"));
+ Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2"));
+ // Var 3 is on GPU
+ Node* var3 = ops::SourceOp("VariableGPU", b.opts().WithName("var3"));
+
+ // Two assigns (reference connections) with two different
+ // colocation groups. Because their colocation groups all map to the
+ // same device, this is a valid assignment.
+ ops::BinaryOp("TestAssign", var1, input,
+ b.opts().WithName("assign1").WithAttr("_class", "loc:1"));
+ ops::BinaryOp("TestAssign", var2, input,
+ b.opts().WithName("assign2").WithAttr("_class", "loc:2"));
+ // Assign to var3, but try to use a colocation group that matches
+ // the assign of var2. This should fail because assign2 must be on CPU
+ // (it has a reference edge on var2), and assign3 must be on GPU,
+ // hence the conflict.
+ ops::BinaryOp("TestAssign", var3, input,
+ b.opts().WithName("assign3").WithAttr("_class", "loc:2"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains("Cannot assign a device to node 'var3': Node had no "
+ "OpKernel registered"));
+}
+
TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.