aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-02-24 09:59:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-02-24 15:33:10 -0800
commit9d84271a2039918994e57c2f962d2ee656f01541 (patch)
tree4d6507d1414ab9e644bfcbcec835c3c40eb03610
parent92383c8754179375ef1e91f270cd60a126cf77c4 (diff)
TensorFlow: Initial support in SimplePlacer for colocation groups,
to be used to colocate based on attributes rather than either names of ops or devices (op names and devices aren't portable). A follow up change will add an ops.colocate_with() to Python that adds this attribute to nodes, and will be used to replace calls to 'with tf.device(foo.device)' in TF library code, which assumes that devices have been specified. Change: 115463464
-rw-r--r--tensorflow/core/common_runtime/simple_placer.cc47
-rw-r--r--tensorflow/core/common_runtime/simple_placer_test.cc91
2 files changed, 138 insertions, 0 deletions
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
index 33a7a63107..37a5ad5efa 100644
--- a/tensorflow/core/common_runtime/simple_placer.cc
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -67,6 +67,7 @@ std::vector<Device*> FilterSupportedDevices(
return filtered_devices;
}
+// TODO(vrv): Remove "@" syntax capability.
bool HasColocatedNodeName(const Node& node) {
return StringPiece(node.def().device()).starts_with("@");
}
@@ -83,6 +84,30 @@ Status ParseColocatedNodeName(const Node& node,
return Status::OK();
}
+// Returns the name of the colocation group of the node by inspecting
+// the "_class" attribute of the NodeDef. Returns "" if it doesn't
+// exist.
+Status ColocationGroup(const Node& node, string* colocation_group) {
+ string class_spec;
+ // TODO(vrv): We should consider adding a GetNodeAttr that returns a
+ // StringPiece, to avoid a copy.
+ Status s = GetNodeAttr(node.def(), "_class", &class_spec);
+ if (!s.ok()) {
+ // No "_class" attribute is equivalent to the empty colocation_group.
+ *colocation_group = "";
+ return Status::OK();
+ }
+
+ StringPiece spec(class_spec);
+ if (!spec.Consume("loc:")) {
+ return errors::InvalidArgument("Node had an invalid _class attribute: ",
+ class_spec);
+ }
+
+ *colocation_group = spec.ToString();
+ return Status::OK();
+}
+
// This class maintains the connected components of a colocation
// constraint graph, and uses this information to assign a satisfying
// device placement to the nodes of the graph.
@@ -134,6 +159,24 @@ class ColocationGraph {
CHECK_GE(member.parent, 0);
members_.resize(member.parent + 1);
members_[member.parent] = std::move(member);
+
+ // When adding the node, identify whether it is part of a
+ // colocation group.
+ string colocation_group;
+ TF_RETURN_IF_ERROR(ColocationGroup(node, &colocation_group));
+ if (!colocation_group.empty()) {
+ // Node has a colocation group specified.
+ auto it = colocation_group_root_.find(colocation_group);
+ if (it == colocation_group_root_.end()) {
+ // This is the first node of the colocation group, so
+ // designate this node as the 'root' of that colocation group.
+ colocation_group_root_[colocation_group] = &node;
+ } else {
+ // Colocate this node with the root.
+ ColocateNodes(node, *(it->second));
+ }
+ }
+
return Status::OK();
}
@@ -447,6 +490,10 @@ class ColocationGraph {
const DeviceSet* device_set_; // Not owned.
const std::vector<DeviceType> device_types_;
const SessionOptions* options_; // Not owned;
+
+ // Maps from a colocation group identifier to the 'root' of that
+ // colocation group.
+ std::unordered_map<string, const Node*> colocation_group_root_;
};
} // namespace
diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc
index 14871c7a29..5dfbb60c2b 100644
--- a/tensorflow/core/common_runtime/simple_placer_test.cc
+++ b/tensorflow/core/common_runtime/simple_placer_test.cc
@@ -218,6 +218,13 @@ class SimplePlacerTest : public ::testing::Test {
GetNodeByName(g_, (name_b))->assigned_device_name()); \
} while (0)
+#define EXPECT_NOT_COLOCATED(g, name_a, name_b) \
+ do { \
+ Graph& g_ = (g); \
+ EXPECT_NE(GetNodeByName(g_, (name_a))->assigned_device_name(), \
+ GetNodeByName(g_, (name_b))->assigned_device_name()); \
+ } while (0)
+
#define EXPECT_DEVICE_TYPE(g, name, expected_device_type) \
EXPECT_EQ(DeviceType(expected_device_type).type(), \
devices_.FindDeviceByName( \
@@ -473,6 +480,90 @@ TEST_F(SimplePlacerTest, TestColocatedChainWithLongRangeColocations) {
}
}
+TEST_F(SimplePlacerTest, TestColocationGroup) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp(
+ "TestInput", b.opts().WithName("in").WithAttr("_class", "loc:ti"));
+ Node* colocated_with_input = ops::UnaryOp(
+ "TestRelu", input,
+ b.opts().WithName("colocated_1").WithAttr("_class", "loc:ti"));
+
+ // This will not be colocated with the input because TestInput is
+ // only availbale on CPU and TestRelu will default to GPU.
+ Node* not_colocated_with_input =
+ ops::UnaryOp("TestRelu", input, b.opts().WithName("foo"));
+ CHECK(colocated_with_input);
+ CHECK(not_colocated_with_input);
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ TF_EXPECT_OK(Place(&g));
+ EXPECT_COLOCATED(g, "in", "colocated_1");
+ EXPECT_NOT_COLOCATED(g, "in", "foo");
+}
+
+TEST_F(SimplePlacerTest, TestColocationGroupWithReferenceConnections) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+ Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1"));
+ Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2"));
+
+ // Two assigns (reference connections) with two different
+ // colocation groups. Because their colocation groups all map to the
+ // same device, this is a valid assignment.
+ ops::BinaryOp("TestAssign", var1, input,
+ b.opts().WithName("assign1").WithAttr("_class", "loc:1"));
+ ops::BinaryOp("TestAssign", var2, input,
+ b.opts().WithName("assign2").WithAttr("_class", "loc:2"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ TF_EXPECT_OK(Place(&g));
+ EXPECT_COLOCATED(g, "in", "var1");
+ EXPECT_COLOCATED(g, "in", "var2");
+ EXPECT_COLOCATED(g, "var1", "assign2");
+ EXPECT_COLOCATED(g, "var2", "assign1");
+}
+
+TEST_F(SimplePlacerTest,
+ TestColocationGroupWithUnsatisfiableReferenceConnections) {
+ Graph g(OpRegistry::Global());
+ { // Scope for temporary variables used to construct g.
+ GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
+ Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
+
+ Node* var1 = ops::SourceOp("VariableCPU", b.opts().WithName("var1"));
+ Node* var2 = ops::SourceOp("VariableCPU", b.opts().WithName("var2"));
+ // Var 3 is on GPU
+ Node* var3 = ops::SourceOp("VariableGPU", b.opts().WithName("var3"));
+
+ // Two assigns (reference connections) with two different
+ // colocation groups. Because their colocation groups all map to the
+ // same device, this is a valid assignment.
+ ops::BinaryOp("TestAssign", var1, input,
+ b.opts().WithName("assign1").WithAttr("_class", "loc:1"));
+ ops::BinaryOp("TestAssign", var2, input,
+ b.opts().WithName("assign2").WithAttr("_class", "loc:2"));
+ // Assign to var3, but try to use a colocation group that matches
+ // the assign of var2. This should fail because assign2 must be on CPU
+ // (it has a reference edge on var2), and assign3 must be on GPU,
+ // hence the conflict.
+ ops::BinaryOp("TestAssign", var3, input,
+ b.opts().WithName("assign3").WithAttr("_class", "loc:2"));
+ TF_EXPECT_OK(BuildGraph(b, &g));
+ }
+
+ Status s = Place(&g);
+ EXPECT_TRUE(
+ StringPiece(s.error_message())
+ .contains("Cannot assign a device to node 'var3': Node had no "
+ "OpKernel registered"));
+}
+
TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.