aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/common_runtime/simple_placer_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/common_runtime/simple_placer_test.cc')
-rw-r--r--tensorflow/core/common_runtime/simple_placer_test.cc71
1 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc
index 06267d71ae..c73ed041ed 100644
--- a/tensorflow/core/common_runtime/simple_placer_test.cc
+++ b/tensorflow/core/common_runtime/simple_placer_test.cc
@@ -1226,5 +1226,76 @@ TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
.contains("Cannot colocate nodes 'var' and 'assign'"));
}
+// Test that a generator node follows its consumers (where there are several
+// consumer nodes on the same devices).
+TEST_F(SimplePlacerTest, TestGeneratorNodeFollowsConsumerNode) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+
+ // A variable is only on CPU
+ Node* var1_cpu =
+ ops::SourceOp("VariableCPU", b.opts().WithName("var1_cpu"));
+ Node* var2_cpu =
+ ops::SourceOp("VariableCPU", b.opts().WithName("var2_cpu"));
+
+ // The constant to be assigned can be on both GPU or CPU.
+ //
+ // Because of the heuristic, it gets placed on CPU to avoid a
+ // copy.
+ Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in"));
+
+ // The assigns are bound to CPU by the reference edge.
+ ops::BinaryOp("TestAssign", var1_cpu, input, b.opts().WithName("assign1"));
+ ops::BinaryOp("TestAssign", var2_cpu, input, b.opts().WithName("assign2"));
+
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ TF_EXPECT_OK(Place(&g));
+ EXPECT_COLOCATED(g, "var1_cpu", "in");
+ EXPECT_COLOCATED(g, "assign1", "in");
+ EXPECT_COLOCATED(g, "var2_cpu", "in");
+ EXPECT_COLOCATED(g, "assign2", "in");
+}
+
+// Test that a generator node does not follow its consumers (where there are
+// several consumers on different devices).
+TEST_F(SimplePlacerTest, TestGeneratorNodeDoesntFollowNonColocatedConsumers) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+
+ // A variable is only on CPU
+ Node* var1_cpu =
+ ops::SourceOp("VariableCPU", b.opts().WithName("var1_cpu"));
+ Node* var2_cpu =
+ ops::SourceOp("VariableCPU", b.opts().WithName("var2_cpu"));
+
+ // The constant to be assigned can be on both GPU or CPU.
+ //
+ // Because of the heuristic, it ought to be on the GPU (cannot be
+ // co-located with both consumers, so goes to the 'standard' place)
+ Node* input = ops::SourceOp("TestCPUGPUOutput", b.opts().WithName("in"));
+
+ // The assigns are bound to CPU by the reference edge.
+ ops::BinaryOp("TestAssign", var1_cpu, input, b.opts().WithName("assign1"));
+ ops::BinaryOp("TestAssign", var2_cpu, input, b.opts().WithName("assign2"));
+
+ TF_EXPECT_OK(BuildGraph(b, &g));
+
+ GetNodeByName(g, "var1_cpu")
+ ->set_assigned_device_name("/job:a/replica:0/task:0/device:fakecpu:1");
+
+ GetNodeByName(g, "var2_cpu")
+ ->set_assigned_device_name("/job:a/replica:0/task:0/device:fakecpu:2");
+ }
+
+ TF_EXPECT_OK(Place(&g));
+ EXPECT_COLOCATED(g, "assign1", "var1_cpu");
+ EXPECT_COLOCATED(g, "assign2", "var2_cpu");
+ EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
+}
+
} // namespace
} // namespace tensorflow