aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-05-02 12:18:25 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-02 13:20:47 -0700
commit99278f5d223385b5a3adb1742c61c4351d95dd37 (patch)
tree3618ca0fc8b5f58d213e972b09e8c011727932db /tensorflow/core
parent7233ae61b0e6deb896dd432764e9c0efae340791 (diff)
SimplePlacer: remove obsolete / never used @colocation device name (superseded
by colocation_groups in _class attr), clean up calls that pass in the name-to-id map, which is no longer needed in SimplePlacer. Should speed up graph construction in C++ because we don't need to iterate over all of the nodes once to build the map. In the future, utilities should rely on node ids instead of node names so the map is not necessary (ideally). Alternatively, a structure in Graph should maintain the mapping. Change: 121302549
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc7
-rw-r--r--tensorflow/core/common_runtime/simple_placer.cc70
-rw-r--r--tensorflow/core/common_runtime/simple_placer.h19
-rw-r--r--tensorflow/core/common_runtime/simple_placer_test.cc135
-rw-r--r--tensorflow/core/distributed_runtime/simple_graph_execution_state.cc23
-rw-r--r--tensorflow/core/distributed_runtime/simple_graph_execution_state.h3
6 files changed, 23 insertions, 234 deletions
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index e505558156..05ca2e8702 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -849,12 +849,7 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds,
device_set_.client_device()->attributes()));
// Run the simple placer after rewriting the graph.
- std::unordered_map<string, int32> node_name_to_cost_map;
- for (Node* n : graph->nodes()) {
- node_name_to_cost_map[n->name()] = n->cost_id();
- }
- SimplePlacer placer(graph.get(), &device_set_, &node_name_to_cost_map,
- &options_);
+ SimplePlacer placer(graph.get(), &device_set_, &options_);
{
mutex_lock l(mu_);
diff --git a/tensorflow/core/common_runtime/simple_placer.cc b/tensorflow/core/common_runtime/simple_placer.cc
index f2de6be0f0..9dd4981b5f 100644
--- a/tensorflow/core/common_runtime/simple_placer.cc
+++ b/tensorflow/core/common_runtime/simple_placer.cc
@@ -67,23 +67,6 @@ std::vector<Device*> FilterSupportedDevices(
return filtered_devices;
}
-// TODO(vrv): Remove "@" syntax capability.
-bool HasColocatedNodeName(const Node& node) {
- return StringPiece(node.def().device()).starts_with("@");
-}
-
-Status ParseColocatedNodeName(const Node& node,
- string* out_colocated_node_name) {
- StringPiece device(node.def().device());
- if (!device.Consume("@")) {
- return errors::InvalidArgument("Malformed colocated node name: '", device,
- "'");
- }
- // TODO(mrry): Validate that the node name is a valid node name.
- *out_colocated_node_name = device.ToString();
- return Status::OK();
-}
-
// Returns the name of the colocation group of the node by inspecting
// the "_class" attribute of the NodeDef. Returns "" if it doesn't
// exist.
@@ -484,11 +467,10 @@ class ColocationGraph {
node.def().op(), "' with these attrs");
}
- // If the NodeDef contains a device that is *not* a colocated node name
- // (i.e. it does not begin with '@') then we interpret it as a (partial)
- // device specification.
+ // If the NodeDef contains a device, then we interpret it as a
+ // (partial) device specification.
string colocated_node_name;
- if (!node.def().device().empty() && !HasColocatedNodeName(node)) {
+ if (!node.def().device().empty()) {
// The user has specified a device in the NodeDef, try to find a
// valid device matching their specification in the set of
// devices.
@@ -551,16 +533,13 @@ class ColocationGraph {
} // namespace
SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices,
- const NodeNameToIdMap* name_to_id_map,
const SessionOptions* options)
: graph_(graph),
devices_(devices),
- name_to_id_map_(name_to_id_map),
options_(options) {}
-SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices,
- const NodeNameToIdMap* name_to_id_map)
- : graph_(graph), devices_(devices), name_to_id_map_(name_to_id_map) {
+SimplePlacer::SimplePlacer(Graph* graph, const DeviceSet* devices)
+ : graph_(graph), devices_(devices) {
options_ = nullptr;
}
@@ -593,33 +572,7 @@ Status SimplePlacer::Run() {
continue;
}
- // 2(a). If node n specifies a colocation constraint as its device name,
- // add an edge from the colocated node to n.
- if (HasColocatedNodeName(*node)) {
- string colocated_node_name;
- status = ParseColocatedNodeName(*node, &colocated_node_name);
- if (!status.ok()) {
- return AttachDef(status, node->def());
- }
- Node* colocated_node;
- status = GetNodeByName(colocated_node_name, &colocated_node);
- if (!status.ok()) {
- return AttachDef(
- errors::InvalidArgument("Colocated node named in device '",
- colocated_node_name, "' does not exist"),
- node->def());
- }
- status = colocation_graph.ColocateNodes(*colocated_node, *node);
- if (!status.ok()) {
- return AttachDef(
- errors::InvalidArgument(
- "Cannot satisfy colocation constraint named in device '",
- colocated_node_name, "': ", status.error_message()),
- node->def());
- }
- }
-
- // 2(b). If `node` has an input edge with reference type, add an
+ // If `node` has an input edge with reference type, add an
// edge from the source of that edge to `node`.
for (const auto& edge : node->in_edges()) {
if (!edge->IsControlEdge() &&
@@ -700,15 +653,4 @@ Status SimplePlacer::Run() {
return Status::OK();
}
-Status SimplePlacer::GetNodeByName(const string& name, Node** out_node) const {
- NodeNameToIdMap::const_iterator iter = name_to_id_map_->find(name);
- if (iter != name_to_id_map_->end()) {
- *out_node = graph_->FindNodeId(iter->second);
- if (*out_node) {
- return Status::OK();
- }
- }
- return errors::NotFound(name);
-}
-
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/simple_placer.h b/tensorflow/core/common_runtime/simple_placer.h
index 8ace1805cd..0fcaa31e6f 100644
--- a/tensorflow/core/common_runtime/simple_placer.h
+++ b/tensorflow/core/common_runtime/simple_placer.h
@@ -37,8 +37,8 @@ namespace tensorflow {
// are granted.
// 3. Nodes connected by edges of a reference type are colocated on
// the same device.
-// 4. Given nodes "A" and "B", if node "B" has the device specification
-// "@A", nodes "A" and "B" will be colocated on the same device.
+// 4. Given nodes "A" and "B", if node "B" has a colocation group
+// "loc:@A", nodes "A" and "B" will be colocated on the same device.
//
// The implementation builds a constraint graph with the same set of
// nodes, and edges that represent colocation constraints between
@@ -57,20 +57,14 @@ class SimplePlacer {
// Creates an instance of the SimplePlacer algorithm for the given
// Graph "graph" (nodes in which may or may not be assigned) on the
- // given DeviceSet "devices". The "name_to_id_map" maps the names of
- // nodes in "g" to their numerical ID.
+ // given DeviceSet "devices".
//
- // REQUIRES: for all mappings (k, v) in "name_to_id_map",
- // graph.FindNodeId(v)->name() == k.
- //
- // The "graph", "devices", and "name_to_id_map" pointer arguments
+ // The "graph" and "devices" pointer arguments
// are borrowed by this SimplePlacer, and must outlive it.
SimplePlacer(Graph* graph, const DeviceSet* devices,
- const NodeNameToIdMap* name_to_id_map,
const SessionOptions* options);
- SimplePlacer(Graph* graph, const DeviceSet* devices,
- const NodeNameToIdMap* name_to_id_map);
+ SimplePlacer(Graph* graph, const DeviceSet* devices);
~SimplePlacer();
@@ -82,11 +76,8 @@ class SimplePlacer {
Status Run();
private:
- Status GetNodeByName(const string& name, Node** out_node) const;
-
Graph* const graph_; // Not owned.
const DeviceSet* const devices_; // Not owned.
- const NodeNameToIdMap* const name_to_id_map_; // Not owned.
const SessionOptions* options_; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(SimplePlacer);
diff --git a/tensorflow/core/common_runtime/simple_placer_test.cc b/tensorflow/core/common_runtime/simple_placer_test.cc
index ac3e7dc34f..cc726f376c 100644
--- a/tensorflow/core/common_runtime/simple_placer_test.cc
+++ b/tensorflow/core/common_runtime/simple_placer_test.cc
@@ -185,7 +185,7 @@ class SimplePlacerTest : public ::testing::Test {
//
// REQUIRES: "*graph" was produced by the most recent call to BuildGraph.
Status Place(Graph* graph, DeviceSet* devices, SessionOptions* options) {
- SimplePlacer placer(graph, devices, &nodes_by_name_, options);
+ SimplePlacer placer(graph, devices, options);
return placer.Run();
}
@@ -512,70 +512,6 @@ TEST_F(SimplePlacerTest, TestReferenceConnectionNoSourceDevice) {
EXPECT_DEVICE_TYPE(g, "assign", DEVICE_CPU);
}
-// Test the handling of '@node_name' colocation constraints, when
-// these are arranged in multiple chains.
-TEST_F(SimplePlacerTest, TestColocatedChain) {
- Graph g(OpRegistry::Global());
- { // Scope for temporary variables used to construct g.
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
- Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
- Node* last_node = input;
- for (int i = 0; i < 100; ++i) {
- if (i % 10 == 0) {
- // Every ten nodes, start a new chain.
- last_node = ops::UnaryOp("TestRelu", last_node,
- b.opts().WithName(strings::StrCat("n_", i)));
- } else {
- // Chain each successive node to the previous one.
- last_node =
- ops::UnaryOp("TestRelu", last_node,
- b.opts()
- .WithName(strings::StrCat("n_", i))
- .WithDevice(strings::StrCat("@n_", i - 1)));
- }
- }
- TF_EXPECT_OK(BuildGraph(b, &g));
- }
-
- TF_EXPECT_OK(Place(&g));
- for (int i = 0; i < 100; ++i) {
- if (i % 10 != 0) {
- EXPECT_COLOCATED(g, strings::StrCat("n_", i - (i % 1)),
- strings::StrCat("n_", i));
- }
- }
-}
-
-// Test the handling of '@node_name' colocation constraints, when the
-// chains are shuffled.
-TEST_F(SimplePlacerTest, TestColocatedChainWithLongRangeColocations) {
- Graph g(OpRegistry::Global());
- { // Scope for temporary variables used to construct g.
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
- Node* input = ops::SourceOp("TestInput", b.opts().WithName("in"));
- Node* last_node = input;
- for (int i = 0; i < 10; ++i) {
- // Start ten chains.
- last_node = ops::UnaryOp("TestRelu", last_node,
- b.opts().WithName(strings::StrCat("n_", i)));
- }
- for (int i = 10; i < 100; ++i) {
- // Add each node to the (i % 10)^th chain.
- last_node = ops::UnaryOp("TestRelu", last_node,
- b.opts()
- .WithName(strings::StrCat("n_", i))
- .WithDevice(strings::StrCat("@n_", i % 10)));
- }
- TF_EXPECT_OK(BuildGraph(b, &g));
- }
-
- TF_EXPECT_OK(Place(&g));
- for (int i = 10; i < 100; ++i) {
- EXPECT_COLOCATED(g, strings::StrCat("n_", i % 10),
- strings::StrCat("n_", i));
- }
-}
-
TEST_F(SimplePlacerTest, TestColocationGroup) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
@@ -724,13 +660,15 @@ TEST_F(SimplePlacerTest, TestColocationAndReferenceConnections) {
// Create a variable colocated with some existing variable, and
// an assignment colocated with a possibly-different variable.
Node* var = ops::SourceOp(
- "TestVariable", b.opts()
- .WithName(strings::StrCat("var_", i))
- .WithDevice(strings::StrCat("@var_", i % 6)));
- ops::BinaryOp("TestAssign", var, input,
- b.opts()
- .WithName(strings::StrCat("assign_", i))
- .WithDevice(strings::StrCat("@assign_", i % 3)));
+ "TestVariable",
+ b.opts()
+ .WithName(strings::StrCat("var_", i))
+ .WithAttr("_class", {strings::StrCat("loc:@var_", i % 6)}));
+ ops::BinaryOp(
+ "TestAssign", var, input,
+ b.opts()
+ .WithName(strings::StrCat("assign_", i))
+ .WithAttr("_class", {strings::StrCat("loc:@assign_", i % 3)}));
}
TF_EXPECT_OK(BuildGraph(b, &g));
}
@@ -938,37 +876,6 @@ TEST_F(SimplePlacerTest, TestNonUniqueAssignedDevice) {
.contains("Assigned device '/job:a' does not match any device"));
}
-// Test that placement fails when a node requests colocation with another
-// node that does not exist.
-TEST_F(SimplePlacerTest, TestUnknownColocatedNode) {
- Graph g(OpRegistry::Global());
- { // Scope for temporary variables used to construct g.
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
- ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("@foo"));
- TF_EXPECT_OK(BuildGraph(b, &g));
- }
-
- Status s = Place(&g);
- EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(StringPiece(s.error_message()).contains("'foo' does not exist"));
-}
-
-// Test that placement fails when a node requests colocation with a
-// malformed node name.
-TEST_F(SimplePlacerTest, TestMalformedColocatedNode) {
- Graph g(OpRegistry::Global());
- { // Scope for temporary variables used to construct g.
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
- ops::SourceOp("TestInput", b.opts().WithName("in").WithDevice("@"));
- TF_EXPECT_OK(BuildGraph(b, &g));
- }
-
- Status s = Place(&g);
- EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("node named in device '' does not exist"));
-}
-
// Test that ops request to be placed on non-existent devices will be relocated
// to existing device of the same type if allow_soft_placement is set.
TEST_F(SimplePlacerTest, TestNonexistentGpuAllowSoftPlacement) {
@@ -1113,27 +1020,5 @@ TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
.contains("Cannot colocate nodes 'var' and 'assign'"));
}
-// Test that placement fails when two nodes have an explicit
-// colocation constraint, and each node requires a mutually
-// incompatible device.
-TEST_F(SimplePlacerTest, TestUnsatisfiableConstraintWithColocatedNodes) {
- Graph g(OpRegistry::Global());
- { // Scope for temporary variables used to construct g.
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
- Node* input = ops::SourceOp("TestInput",
- b.opts().WithName("in").WithDevice("/gpu:0"));
- Node* relu_1 = ops::UnaryOp("TestRelu", input,
- b.opts().WithName("relu_1").WithDevice("@in"));
- ops::UnaryOp("ReluGPU", relu_1,
- b.opts().WithName("relu_2").WithDevice("@relu_1"));
- TF_EXPECT_OK(BuildGraph(b, &g));
- }
-
- Status s = Place(&g);
- EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(StringPiece(s.error_message())
- .contains("Cannot colocate nodes 'relu_1' and 'relu_2'"));
-}
-
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/simple_graph_execution_state.cc b/tensorflow/core/distributed_runtime/simple_graph_execution_state.cc
index 62d1f5903c..2b24a59c6f 100644
--- a/tensorflow/core/distributed_runtime/simple_graph_execution_state.cc
+++ b/tensorflow/core/distributed_runtime/simple_graph_execution_state.cc
@@ -129,35 +129,15 @@ Status SimpleGraphExecutionState::InitBaseGraph() {
GraphConstructorOptions opts;
TF_RETURN_IF_ERROR(
ConvertGraphDefToGraph(opts, original_graph_def_, new_base.get()));
- for (const Node* n : new_base->nodes()) {
- VLOG(2) << "Mapping " << n->name() << " to " << n->cost_id();
- node_name_to_cost_id_map_[n->name()] = n->cost_id();
- }
Status status = PreliminaryPlace(*new_base);
if (!status.ok()) {
- node_name_to_cost_id_map_.clear();
return status;
}
base_ = new_base.release();
return Status::OK();
}
-Status SimpleGraphExecutionState::GlobalNodeDefByName(const string& name,
- NodeDef* out) {
- NodeNameToCostIdMap::const_iterator iter =
- node_name_to_cost_id_map_.find(name);
- if (iter != node_name_to_cost_id_map_.end()) {
- mutex_lock l(mu_); // could use reader lock
- const Node* node = placed_->FindNodeId(iter->second);
- if (node) {
- *out = node->def();
- return Status::OK();
- }
- }
- return errors::NotFound("Node name: ", name);
-}
-
Status SimpleGraphExecutionState::PreliminaryPlace(const Graph& base) {
VLOG(1) << "PreliminaryPlace";
Graph* ng = new Graph(ops_);
@@ -284,8 +264,7 @@ Status SimpleGraphExecutionState::DeviceIsCompatible(
}
Status SimpleGraphExecutionState::SimplePlacement(Graph* graph) {
- SimplePlacer placer(graph, device_set_, &node_name_to_cost_id_map_,
- session_options_);
+ SimplePlacer placer(graph, device_set_, session_options_);
// TODO(mrry): Consider making the SimplePlacer cancelable.
return placer.Run();
}
diff --git a/tensorflow/core/distributed_runtime/simple_graph_execution_state.h b/tensorflow/core/distributed_runtime/simple_graph_execution_state.h
index 6d065437d8..95fcc6130c 100644
--- a/tensorflow/core/distributed_runtime/simple_graph_execution_state.h
+++ b/tensorflow/core/distributed_runtime/simple_graph_execution_state.h
@@ -145,9 +145,6 @@ class SimpleGraphExecutionState {
PlaceMap stateful_placements_ GUARDED_BY(mu_);
std::vector<Node*> missing_stateful_placements_ GUARDED_BY(mu_);
- // Map from name to Node for the full graph in placed_.
- NodeNameToCostIdMap node_name_to_cost_id_map_;
-
TF_DISALLOW_COPY_AND_ASSIGN(SimpleGraphExecutionState);
};