diff options
author | Vijay Vasudevan <vrv@google.com> | 2016-02-24 09:59:50 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-02-24 15:33:10 -0800 |
commit | 9d84271a2039918994e57c2f962d2ee656f01541 (patch) | |
tree | 4d6507d1414ab9e644bfcbcec835c3c40eb03610 /tensorflow/core/common_runtime/simple_placer_test.cc | |
parent | 92383c8754179375ef1e91f270cd60a126cf77c4 (diff) |
TensorFlow: Initial support in SimplePlacer for colocation groups,
to be used to colocate based on attributes rather than either
names of ops or devices (op names and devices aren't portable).
A follow up change will add an ops.colocate_with() to Python that adds
this attribute to nodes, and will be used to replace calls to 'with
tf.device(foo.device)' in TF library code, which assumes that devices
have been specified.
Change: 115463464
Diffstat (limited to 'tensorflow/core/common_runtime/simple_placer_test.cc')
-rw-r--r-- | tensorflow/core/common_runtime/simple_placer_test.cc | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc index 14871c7a29..5dfbb60c2b 100644 --- a/tensorflow/core/common_runtime/simple_placer_test.cc +++ b/tensorflow/core/common_runtime/simple_placer_test.cc @@ -218,6 +218,13 @@ class SimplePlacerTest : public ::testing::Test { GetNodeByName(g_, (name_b))->assigned_device_name()); \ } while (0) +#define EXPECT_NOT_COLOCATED(g, name_a, name_b) \ + do { \ + Graph& g_ = (g); \ + EXPECT_NE(GetNodeByName(g_, (name_a))->assigned_device_name(), \ + GetNodeByName(g_, (name_b))->assigned_device_name()); \ + } while (0) + #define EXPECT_DEVICE_TYPE(g, name, expected_device_type) \ EXPECT_EQ(DeviceType(expected_device_type).type(), \ devices_.FindDeviceByName( \ @@ -473,6 +480,90 @@ TEST_F(SimplePlacerTest, TestColocatedChainWithLongRangeColocations) { } } +TEST_F(SimplePlacerTest, TestColocationGroup) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp( + "TestInput", b.opts().WithName("in").WithAttr("_class", "loc:ti")); + Node* colocated_with_input = ops::UnaryOp( + "TestRelu", input, + b.opts().WithName("colocated_1").WithAttr("_class", "loc:ti")); + + // This will not be colocated with the input because TestInput is + // only availbale on CPU and TestRelu will default to GPU. + Node* not_colocated_with_input = + ops::UnaryOp("TestRelu", input, b.opts().WithName("foo")); + CHECK(colocated_with_input); + CHECK(not_colocated_with_input); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_COLOCATED(g, "in", "colocated_1"); + EXPECT_NOT_COLOCATED(g, "in", "foo"); +} + +TEST_F(SimplePlacerTest, TestColocationGroupWithReferenceConnections) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1")); + Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2")); + + // Two assigns (reference connections) with two different + // colocation groups. Because their colocation groups all map to the + // same device, this is a valid assignment. + ops::BinaryOp("TestAssign", var1, input, + b.opts().WithName("assign1").WithAttr("_class", "loc:1")); + ops::BinaryOp("TestAssign", var2, input, + b.opts().WithName("assign2").WithAttr("_class", "loc:2")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + TF_EXPECT_OK(Place(&g)); + EXPECT_COLOCATED(g, "in", "var1"); + EXPECT_COLOCATED(g, "in", "var2"); + EXPECT_COLOCATED(g, "var1", "assign2"); + EXPECT_COLOCATED(g, "var2", "assign1"); +} + +TEST_F(SimplePlacerTest, + TestColocationGroupWithUnsatisfiableReferenceConnections) { + Graph g(OpRegistry::Global()); + { // Scope for temporary variables used to construct g. + GraphDefBuilder b(GraphDefBuilder::kFailImmediately); + Node* input = ops::SourceOp("TestInput", b.opts().WithName("in")); + + Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1")); + Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2")); + // Var 3 is on GPU + Node* var3 = ops::SourceOp("VariableGPU", b.opts().WithName("var3")); + + // Two assigns (reference connections) with two different + // colocation groups. Because their colocation groups all map to the + // same device, this is a valid assignment. + ops::BinaryOp("TestAssign", var1, input, + b.opts().WithName("assign1").WithAttr("_class", "loc:1")); + ops::BinaryOp("TestAssign", var2, input, + b.opts().WithName("assign2").WithAttr("_class", "loc:2")); + // Assign to var3, but try to use a colocation group that matches + // the assign of var2. This should fail because assign2 must be on CPU + // (it has a reference edge on var2), and assign3 must be on GPU, + // hence the conflict. + ops::BinaryOp("TestAssign", var3, input, + b.opts().WithName("assign3").WithAttr("_class", "loc:2")); + TF_EXPECT_OK(BuildGraph(b, &g)); + } + + Status s = Place(&g); + EXPECT_TRUE( + StringPiece(s.error_message()) + .contains("Cannot assign a device to node 'var3': Node had no " + "OpKernel registered")); +} + TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) { Graph g(OpRegistry::Global()); { // Scope for temporary variables used to construct g. |