aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-18 12:53:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-18 13:00:35 -0700
commit38bcb3c02fbc5185d6c1fb7e8327a070284b66e4 (patch)
tree39befc19c046875ef9ba9ee675a878e048ebad4f /tensorflow/tools/graph_transforms
parent09ff3f7296a66c39535e097ecb6b82e3fc42ba30 (diff)
Bug fixes for fold_constants_lib.
1. Tensor names in TF may be in the form of "a:0", "a:1", or "a" as a shorthand notation of "a:0". FoldConstant library always expected the shorthand notation, and did not handle the cases where explicit notation was passed to input or output list. This means that this library could not handle the case when input or output were not the first output of a node. 2. To match the input nodes in the original graph and the added Recv nodes in rewritten graph, FoldConstant library used prefix matching. Unfortunately, this means that when a input name is a prefix of another input name, there is possibility that wrong Recv node gets matched. For example, if input names were "placeholder" and "placeholder_1", then it did not handle the case very well. 3. RemoveUnusedNodes() in FoldConstants lib could remove nodes which output depended on. This happened when an input name points to a node with multiple outputs and not all outputs of that node were included in the input names. 4. ReplaceSendRecvs() in FoldConstants lib assumed that all input nodes are removed during rewriting the graph. This assumption is not necessarily true, and it could add a duplicate node in the graph. PiperOrigin-RevId: 172641947
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_lib.cc202
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_test.cc85
2 files changed, 175 insertions, 112 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
index 30290c7a16..f2934a79bd 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
@@ -17,12 +17,20 @@ limitations under the License.
#include <algorithm>
#include <iterator>
+#include <map>
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
@@ -30,33 +38,38 @@ limitations under the License.
namespace tensorflow {
namespace graph_transforms {
+namespace {
+using StringPieceSet = std::unordered_set<StringPiece, StringPiece::Hasher>;
+template <typename T>
+using StringPieceMap = std::unordered_map<StringPiece, T, StringPiece::Hasher>;
+} // namespace
Status ReplaceSendRecvs(const GraphDef& original_graph_def,
const GraphDef& rewritten_graph_def,
const std::vector<string>& inputs,
const std::vector<string>& outputs,
GraphDef* output_graph_def) {
- std::map<string, const NodeDef*> original_map;
- MapNamesToNodes(original_graph_def, &original_map);
- std::map<string, string> new_node_names;
- for (const NodeDef& node : rewritten_graph_def.node()) {
- // If the op isn't a Recv, or it was in the original, nothing to do.
- if ((node.op() != "_Recv") || (original_map.count(node.name()) == 1)) {
- continue;
- }
- // See if it matches an input from the original.
- for (const string& input : inputs) {
- // Here we rely on the naming convention for the Recv nodes that
- // RewriteGraphForExecution adds in the place of the feed inputs.
- string input_prefix = "_recv_" + input + "_";
- if (StringPiece(node.name()).starts_with(input_prefix)) {
- // If it does, prepare to rename any inputs that refer to it.
- new_node_names[node.name()] = input;
- }
- }
+ // recv_node_names serves as a string storage for recv node names.
+ std::vector<string> recv_node_names(inputs.size());
+ StringPieceMap<TensorId> recv_node_map;
+ StringPieceSet input_nodes;
+ for (int i = 0; i < inputs.size(); ++i) {
+ // RewriteGraphForExecution adds a recv node for each input edge. We assume
+ // here that adding such recv node did not fail. For example, the original
+ // graph did not already have a node with the name for the new added recv
+ // node.
+ TensorId id = ParseTensorName(inputs[i]);
+ input_nodes.insert(id.first);
+ string& recv_node_name = recv_node_names[i];
+ recv_node_name = strings::StrCat("_recv_", id.first, "_", id.second);
+ recv_node_map.emplace(recv_node_name, id);
+ }
+
+ StringPieceMap<const NodeDef*> original_map;
+ for (const NodeDef& node : original_graph_def.node()) {
+ original_map.emplace(node.name(), &node);
}
- std::vector<NodeDef> nodes_to_add;
for (const NodeDef& node : rewritten_graph_def.node()) {
if ((node.op() == "_Send") || (node.op() == "_Recv")) {
// If the op is a Send or Recv that wasn't in the original, skip it.
@@ -64,55 +77,68 @@ Status ReplaceSendRecvs(const GraphDef& original_graph_def,
continue;
}
}
- NodeDef new_node;
- new_node = node;
- new_node.mutable_input()->Clear();
- for (const string& old_input : node.input()) {
- string input_prefix;
- string input_node_name;
- string input_suffix;
- NodeNamePartsFromInput(old_input, &input_prefix, &input_node_name,
- &input_suffix);
- string new_input;
- if (new_node_names.count(input_node_name) > 0) {
- new_input =
- input_prefix + new_node_names[input_node_name] + input_suffix;
- } else {
- new_input = old_input;
+
+ NodeDef* new_node = output_graph_def->add_node();
+ new_node->MergeFrom(node);
+ for (int i = 0; i < new_node->input_size(); ++i) {
+ string& input = *new_node->mutable_input(i);
+ TensorId id = ParseTensorName(input);
+ const auto iter = recv_node_map.find(id.first);
+ if (iter != recv_node_map.end()) {
+ // The node being substituted is a Recv node, and it has only one
+ // output. If this input is not a control input, then replace the input
+ // with the mapped value. Otherwise, replace the node name only.
+ if (id.second != Graph::kControlSlot) {
+ CHECK_EQ(id.second, 0);
+ input = iter->second.ToString();
+ } else {
+ id.first = iter->second.first;
+ input = id.ToString();
+ }
}
- *(new_node.mutable_input()->Add()) = new_input;
}
- nodes_to_add.push_back(new_node);
- }
- for (std::pair<string, string> entry : new_node_names) {
- string removed_node_name = entry.second;
- const NodeDef* removed_node = original_map[removed_node_name];
- NodeDef new_node;
- new_node = *removed_node;
- nodes_to_add.push_back(new_node);
+
+ // RewriteGraphForExecution() did not remove this input node. Remove this
+ // node name from input_nodes so that a duplicate does not get added to the
+ // output_graph_def.
+ auto iter = input_nodes.find(new_node->name());
+ if (iter != input_nodes.end()) {
+ input_nodes.erase(iter);
+ }
}
- for (const NodeDef& node : nodes_to_add) {
- *output_graph_def->mutable_node()->Add() = node;
+ // Some input nodes are removed in rewrite_graph_def. Add those nodes to
+ // output_graph_def.
+ for (StringPiece name : input_nodes) {
+ const NodeDef& removed_node = *CHECK_NOTNULL(original_map[name]);
+ output_graph_def->add_node()->MergeFrom(removed_node);
}
+
return Status::OK();
}
Status RemoveUnusedNodes(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
- std::map<string, const NodeDef*> node_map;
- MapNamesToNodes(input_graph_def, &node_map);
+ StringPieceMap<const NodeDef*> node_map;
+ for (const NodeDef& node : input_graph_def.node()) {
+ node_map.emplace(node.name(), &node);
+ }
- std::set<string> used_nodes;
+ std::unordered_set<TensorId, TensorId::Hasher> input_names;
for (const string& input : context.input_names) {
- used_nodes.insert(input);
+ input_names.insert(ParseTensorName(input));
+ }
+ StringPieceSet used_nodes;
+ StringPieceSet current_nodes;
+ for (const string& name : context.output_names) {
+ TensorId id = ParseTensorName(name);
+ used_nodes.insert(id.first);
+ current_nodes.insert(id.first);
}
- std::vector<string> current_nodes = context.output_names;
while (!current_nodes.empty()) {
- std::set<string> next_nodes;
- for (const string& node_name : current_nodes) {
- used_nodes.insert(node_name);
+ StringPieceSet next_nodes;
+ for (StringPiece node_name : current_nodes) {
if (node_map.count(node_name) == 0) {
LOG(ERROR) << "Bad graph structure, no node named '" << node_name
<< "' found for input lookup";
@@ -120,14 +146,20 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def,
node_name, "' found for input lookup");
}
const NodeDef& node = *(node_map[node_name]);
- for (const string& input_name : node.input()) {
- const string& input_node_name = NodeNameFromInput(input_name);
- if (used_nodes.count(input_node_name) == 0) {
- next_nodes.insert(input_node_name);
+ for (const string& input : node.input()) {
+ TensorId id = ParseTensorName(input);
+ if (input_names.count(id) > 0) {
+ continue;
+ }
+ if (used_nodes.insert(id.first).second) {
+ next_nodes.insert(id.first);
}
}
}
- current_nodes = std::vector<string>(next_nodes.begin(), next_nodes.end());
+ current_nodes.swap(next_nodes);
+ }
+ for (const TensorId& id : input_names) {
+ used_nodes.insert(id.first);
}
FilterGraphDef(
input_graph_def,
@@ -141,7 +173,7 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def,
Status ShapeHandleToTensorShape(const shape_inference::ShapeHandle& handle,
shape_inference::InferenceContext* context,
PartialTensorShape* shape) {
- // The default is already unknown
+ // The default is already unknown.
if (!context->RankKnown(handle)) return Status::OK();
std::vector<int64> dims(context->Rank(handle));
@@ -151,47 +183,6 @@ Status ShapeHandleToTensorShape(const shape_inference::ShapeHandle& handle,
return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
}
-Status ShapeForNode(const TransformFuncContext& context,
- const string& node_name, TensorShape* result,
- bool* has_shape_specified) {
- *has_shape_specified = false;
-
- // Check to see if we have been given a default for all placeholders.
- if (context.params.count("type")) {
- if (context.params.at("shape").size() != 1) {
- return errors::InvalidArgument(
- "You must pass no more than one default 'shape' to "
- "fold_constants");
- }
- const string& shape_string = context.params.at("shape")[0];
- TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result));
- *has_shape_specified = true;
- }
-
- // See if there's a particular type specified for this placeholder.
- if (context.params.count("name") || context.params.count("type_for_name")) {
- if (!context.params.count("name") ||
- !context.params.count("type_for_name") ||
- (context.params.at("type_for_name").size() !=
- context.params.at("name").size())) {
- return errors::InvalidArgument(
- "You must pass a 'shape_for_name' arg for every 'name', e.g. "
- "fold_constants(name=foo, shape_for_name=\"2,2,1\", name=bar, "
- "shape_for_name=\"1\"");
- }
- const int name_count = context.params.at("name").size();
- for (int i = 0; i < name_count; ++i) {
- if (context.params.at("name")[i] == node_name) {
- const string& shape_string = context.params.at("shape_for_name")[i];
- TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result));
- *has_shape_specified = true;
- }
- }
- }
-
- return Status::OK();
-}
-
// Converts any sub-graphs that can be resolved into constant expressions into
// single Const ops.
Status FoldConstants(const GraphDef& input_graph_def,
@@ -215,17 +206,6 @@ Status FoldConstants(const GraphDef& input_graph_def,
GraphDef cleaned_graph_def;
RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_graph_def);
- // Set specified shapes.
- for (NodeDef& node : *cleaned_graph_def.mutable_node()) {
- TensorShape shape;
- bool has_shape_specified;
- TF_RETURN_IF_ERROR(
- ShapeForNode(context, node.name(), &shape, &has_shape_specified));
- if (has_shape_specified) {
- SetNodeAttr("shape", shape, &node);
- }
- }
-
TF_RETURN_IF_ERROR(
ImportGraphDef({}, cleaned_graph_def, &input_graph, &shape_refiner));
} else {
diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc
index fd4188a6a4..41106de008 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc
@@ -74,6 +74,9 @@ class ConstantFoldingTest : public ::testing::Test {
TestConstantFolding(graph_def,
{{"placeholder_expect_remains", placeholder_tensor}},
{}, {"output_expect_remains"}, {});
+ TestConstantFolding(graph_def,
+ {{"placeholder_expect_remains:0", placeholder_tensor}},
+ {}, {"output_expect_remains:0"}, {});
}
void TestOpExclusionAdd() {
@@ -256,10 +259,40 @@ class ConstantFoldingTest : public ::testing::Test {
EXPECT_EQ(0, node_map.count("new_send"));
}
+ void TestReplaceSendRecvsPrefixNames() {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ auto o_root = tensorflow::Scope::NewRootScope();
+ auto a = Placeholder(o_root.WithOpName("placeholder"), DT_FLOAT);
+ auto b = Placeholder(o_root.WithOpName("placeholder_1"), DT_FLOAT);
+ auto add_o = Add(o_root.WithOpName("add"), a, b);
+ GraphDef o_graph_def;
+ TF_ASSERT_OK(o_root.ToGraphDef(&o_graph_def));
+
+ auto n_root = tensorflow::Scope::NewRootScope();
+ auto c = _Recv(n_root.WithOpName("_recv_placeholder_0"), DT_FLOAT, "", "",
+ 0, "");
+ auto d = _Recv(n_root.WithOpName("_recv_placeholder_1_0"), DT_FLOAT, "", "",
+ 0, "");
+ auto add_n = Add(n_root.WithOpName("add"), c, d);
+ GraphDef n_graph_def;
+ TF_ASSERT_OK(n_root.ToGraphDef(&n_graph_def));
+
+ GraphDef result_graph_def;
+ TF_ASSERT_OK(graph_transforms::ReplaceSendRecvs(
+ o_graph_def, n_graph_def, {"placeholder", "placeholder_1"}, {"add"},
+ &result_graph_def));
+
+ std::map<string, const NodeDef*> node_map;
+ graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
+ EXPECT_EQ(1, node_map.count("placeholder"));
+ EXPECT_EQ(1, node_map.count("placeholder_1"));
+ EXPECT_EQ(1, node_map.count("add"));
+ }
+
void TestRemoveUnusedNodes() {
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
auto root = tensorflow::Scope::NewRootScope();
- using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
const int width = 100;
@@ -295,6 +328,48 @@ class ConstantFoldingTest : public ::testing::Test {
EXPECT_EQ(1, node_map.count("output"));
EXPECT_EQ(0, node_map.count("unused"));
}
+
+ void TestRemoveUnusedNodesMultipleOutputs() {
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+ auto root = tensorflow::Scope::NewRootScope();
+
+ // a b
+ // \ /
+ // shape_n
+ // \ /
+ // c
+ auto a = Placeholder(root.WithOpName("a"), DT_FLOAT);
+ auto b = Placeholder(root.WithOpName("b"), DT_FLOAT);
+ auto shape_n = ShapeN(root.WithOpName("shape_n"), {Output(a), Output(b)});
+ auto c = Add(root.WithOpName("c"), shape_n[0], shape_n[1]);
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+ GraphDef result_graph_def;
+ TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
+ graph_def, {{shape_n[0].name()}, {"c"}}, &result_graph_def));
+
+ // Only one output of shape_n node is fed input. Hence the graph search
+ // should propagate to inputs of shape_n. Nothing to remove here.
+ std::map<string, const NodeDef*> node_map;
+ graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
+ EXPECT_EQ(1, node_map.count("a"));
+ EXPECT_EQ(1, node_map.count("b"));
+ EXPECT_EQ(1, node_map.count("c"));
+
+ result_graph_def.Clear();
+ TF_ASSERT_OK(graph_transforms::RemoveUnusedNodes(
+ graph_def, {{shape_n[0].name(), shape_n[1].name()}, {"c"}},
+ &result_graph_def));
+
+ // Both outputs of shape_n node are fed inputs. shape_n does not function
+ // and inputs to shape_n should be removed.
+ node_map.clear();
+ graph_transforms::MapNamesToNodes(result_graph_def, &node_map);
+ EXPECT_EQ(0, node_map.count("a"));
+ EXPECT_EQ(0, node_map.count("b"));
+ EXPECT_EQ(1, node_map.count("c"));
+ }
};
TEST_F(ConstantFoldingTest, TestSimpleAdd) { TestSimpleAdd(); }
@@ -309,7 +384,15 @@ TEST_F(ConstantFoldingTest, TestPreserveOutputShapes) {
TEST_F(ConstantFoldingTest, TestReplaceSendRecvs) { TestReplaceSendRecvs(); }
+TEST_F(ConstantFoldingTest, TestReplaceSendRecvsPrefixNames) {
+ TestReplaceSendRecvsPrefixNames();
+}
+
TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); }
+TEST_F(ConstantFoldingTest, TestRemoveUnusedNodesMultipleOutputs) {
+ TestRemoveUnusedNodesMultipleOutputs();
+}
+
} // namespace graph_transforms
} // namespace tensorflow