aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-26 09:47:58 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-26 09:55:27 -0800
commitc1338b14149b6313280bea455ec1dec2a336bd31 (patch)
treea6c458d25f3968b7f8d53f3c5f045b0750c05103 /tensorflow/tools/graph_transforms
parentc4ace4e2abf6f19f34357e53ba4aebce5113af01 (diff)
Updating sparsify_gather.
PiperOrigin-RevId: 183402917
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r-- tensorflow/tools/graph_transforms/sparsify_gather.cc      | 109
-rw-r--r-- tensorflow/tools/graph_transforms/sparsify_gather_test.cc |  40
2 files changed, 110 insertions, 39 deletions
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc
index 96324d0dea..593c654f9f 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <cmath>
#include <memory>
+#include <unordered_map>
#include "tensorflow/c/checkpoint_reader.h"
#include "tensorflow/core/framework/tensor.h"
@@ -28,9 +29,10 @@ limitations under the License.
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
-using strings::StrCat;
using str_util::Join;
using str_util::Split;
+using str_util::StringReplace;
+using strings::StrCat;
namespace graph_transforms {
@@ -89,7 +91,7 @@ Status ObtainTensorSlice(const GraphDef& input_graph_def,
string* shape_slice_string) {
string restore_node_name;
for (const auto& node : input_graph_def.node()) {
- std::vector<string> node_name_parts = str_util::Split(node.name(), "/");
+ std::vector<string> node_name_parts = Split(node.name(), "/");
if (node_name_parts.size() == 2 &&
StringPiece(node_name_parts[0]).starts_with("save") &&
StringPiece(node_name_parts[1]).starts_with("Assign") &&
@@ -119,13 +121,13 @@ Status ObtainTensorSlice(const GraphDef& input_graph_def,
}
string GetMonolithicTensorKey(const string& tensor_slice_name) {
- std::vector<string> names = str_util::Split(tensor_slice_name, "/");
+ std::vector<string> names = Split(tensor_slice_name, "/");
CHECK_GE(names.size(), 2);
CHECK(StringPiece(names[names.size() - 1]).starts_with("part_"));
// Remove the "part_x" suffix
names.pop_back();
- return str_util::Join(names, "/");
+ return Join(names, "/");
}
Status ReadTensorFromCheckpoint(
@@ -193,6 +195,15 @@ Status SparsifyGatherInternal(
GraphDef current_graph_def = input_graph_def;
bool any_match_found = false;
+ // Populate references.
+ std::unordered_map<string, int> refs;
+ for (const auto& node : current_graph_def.node()) {
+ for (const auto& input : node.input()) {
+ auto parsed_input = StringReplace(input, "^", "", true);
+ refs[parsed_input] += 1;
+ }
+ }
+
// The subgraphs may have overlapping components, therefore GraphMatcher
// doesn't return all subgraphs in one round -- this has to be multi-round
// update.
@@ -200,15 +211,15 @@ Status SparsifyGatherInternal(
any_match_found = false;
GraphDef replaced_graph_def = current_graph_def;
std::vector<string> init_table_node_names;
- std::vector<string> removed_variable_names;
+ std::vector<string> removed_node_names;
TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
current_graph_def, pattern,
[&ckpt_reader, &any_match_found, &init_table_node_names,
- &shapes_and_slices, &removed_variable_names](
- const NodeMatch& match, const std::set<string>& input_nodes,
- const std::set<string>& output_nodes,
- std::vector<NodeDef>* new_nodes) {
+ &shapes_and_slices, &removed_node_names,
+ &refs](const NodeMatch& match, const std::set<string>& input_nodes,
+ const std::set<string>& output_nodes,
+ std::vector<NodeDef>* new_nodes) {
any_match_found = true;
// The captured subgraph should be of the following pattern:
@@ -291,8 +302,12 @@ Status SparsifyGatherInternal(
weights_node.name(), ckpt_reader,
(*shapes_and_slices)[weights_node.name()], &weight));
// Add both weight and identity node names.
- removed_variable_names.push_back(weights_node.name());
- removed_variable_names.push_back(match.inputs[0].node.name());
+ removed_node_names.push_back(weights_node.name());
+ removed_node_names.push_back(match.inputs[0].node.name());
+ for (auto input_node : match.inputs[0].node.input()) {
+ auto parsed_input = StringReplace(input_node, "^", "", true);
+ refs[parsed_input]--;
+ }
}
Tensor indices_tensor;
Tensor values_tensor;
@@ -362,15 +377,23 @@ Status SparsifyGatherInternal(
// Connect nodes
AddNodeInput(hashtable_node.name(), &init_table_node);
+ refs[hashtable_node.name()]++;
AddNodeInput(indices_node.name(), &init_table_node);
+ refs[indices_node.name()]++;
AddNodeInput(values_node.name(), &init_table_node);
+ refs[values_node.name()]++;
AddNodeInput(hashtable_node.name(), &lookup_node);
+ refs[hashtable_node.name()]++;
AddNodeInput(gather_node.input(1), &lookup_node);
+ refs[gather_node.input(1)]++;
AddNodeInput(default_value_node.name(), &lookup_node);
+ refs[default_value_node.name()]++;
AddNodeInput(lookup_node.name(), &expand_dims_node);
+ refs[lookup_node.name()]++;
AddNodeInput(dim_idx_node.name(), &expand_dims_node);
+ refs[dim_idx_node.name()]++;
// Copy 'ids' input of original 'Gather'
new_nodes->push_back(match.inputs[1].node);
@@ -404,22 +427,44 @@ Status SparsifyGatherInternal(
for (const string& name : init_table_node_names) {
// Add control dependence from init_table_node to group_deps_node
AddNodeInput(StrCat("^", name), init_op);
+ refs[name]++;
+ }
+
+ // Erase inputs and outputs as they are not considered for deletion.
+ for (const auto& output : context.output_names) {
+ refs.erase(output);
+ }
+
+ for (const auto& input : context.input_names) {
+ refs.erase(input);
}
- // Remove all dependencies associated with removed variables.
- while (!removed_variable_names.empty()) {
- auto name = removed_variable_names.back();
- removed_variable_names.pop_back();
+ // Add nodes with a reference count of 0 for deletion.
+ for (auto entry : refs) {
+ if (entry.second == 0) {
+ removed_node_names.push_back(entry.first);
+ }
+ }
+
+ while (!removed_node_names.empty()) {
+ auto name = removed_node_names.back();
+ removed_node_names.pop_back();
+
int i = 0;
while (i < replaced_graph_def.node_size()) {
- if (!replaced_graph_def.node(i).input_size()) {
- if (replaced_graph_def.node(i).name() == name) {
- replaced_graph_def.mutable_node()->SwapElements(
- i, replaced_graph_def.node_size() - 1);
- replaced_graph_def.mutable_node()->RemoveLast();
- continue;
+ // Revisit this to see if we can safely remove RestoreV2 nodes.
+ if ((replaced_graph_def.node(i).name() == name) &&
+ (replaced_graph_def.node(i).op() != "RestoreV2")) {
+ for (const auto& input : replaced_graph_def.node(i).input()) {
+ auto parsed_input = StringReplace(input, "^", "", true);
+ refs[parsed_input] -= 1;
+ if (refs[parsed_input] == 0) {
+ removed_node_names.push_back(parsed_input);
+ }
}
- i++;
+ replaced_graph_def.mutable_node()->SwapElements(
+ i, replaced_graph_def.node_size() - 1);
+ replaced_graph_def.mutable_node()->RemoveLast();
continue;
}
int j = 0;
@@ -433,18 +478,16 @@ Status SparsifyGatherInternal(
}
j++;
}
- if ((replaced_graph_def.node(i).input_size() == 0) ||
- (replaced_graph_def.node(i).op() == "Assign" &&
- replaced_graph_def.node(i).input_size() == 1)) {
- removed_variable_names.push_back(replaced_graph_def.node(i).name());
- if (replaced_graph_def.node(i).input_size() == 1) {
- removed_variable_names.push_back(
- replaced_graph_def.node(i).input(0));
+ if (!replaced_graph_def.node(i).input_size()) {
+ if ((refs.find(replaced_graph_def.node(i).name()) != refs.end()) &&
+ (refs[replaced_graph_def.node(i).name()] == 0)) {
+ removed_node_names.push_back(replaced_graph_def.node(i).name());
}
- replaced_graph_def.mutable_node()->SwapElements(
- i, replaced_graph_def.node_size() - 1);
- replaced_graph_def.mutable_node()->RemoveLast();
- continue;
+ }
+
+ if (replaced_graph_def.node(i).op() == "Assign" &&
+ replaced_graph_def.node(i).input_size() == 1) {
+ removed_node_names.push_back(replaced_graph_def.node(i).name());
}
i++;
}
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
index 000568a0cc..6627df1331 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
@@ -80,6 +80,8 @@ class SparsifyGatherTest : public ::testing::Test {
// Build the graph.
NodeDef* input_node = CreateNode("ids", "Const", {}, &graph_def);
NodeDef* w_node;
+ NodeDef* zeros_const;
+ NodeDef* zeros_shape;
NodeDef* zeros_node;
NodeDef* assign_node;
@@ -92,8 +94,12 @@ class SparsifyGatherTest : public ::testing::Test {
} else {
w_node = CreateNode("w/part_1", "VariableV2", {}, &graph_def);
- zeros_node =
- CreateNode("w/part_1/Initializer/zeros", "Const", {}, &graph_def);
+ zeros_shape = CreateNode("w/part_1/Initializer/zeros/shape_as_tensor",
+ "Const", {}, &graph_def);
+ zeros_const = CreateNode("w/part_1/Initializer/zeros/Const", "Const", {},
+ &graph_def);
+ zeros_node = CreateNode("w/part_1/Initializer/zeros", "Fill",
+ {zeros_shape, zeros_const}, &graph_def);
assign_node = CreateNode("w/part_1/Assign", "Assign",
{w_node, zeros_node}, &graph_def);
@@ -151,6 +157,9 @@ class SparsifyGatherTest : public ::testing::Test {
MapNamesToNodes(result, &node_lookup);
// Check nodes.
+ EXPECT_EQ(0,
+ node_lookup.count("w/part_1/Initializer/zeros/shape_as_tensor"));
+ EXPECT_EQ(0, node_lookup.count("w/part_1/Initializer/zeros/Const"));
EXPECT_EQ(0, node_lookup.count("w/part_1/Initializer/zeros"));
EXPECT_EQ(0, node_lookup.count("w/part_1/Assign"));
@@ -247,7 +256,11 @@ class SparsifyGatherTest : public ::testing::Test {
// Two partitions
NodeDef* w_node1;
NodeDef* w_node2;
+ NodeDef* zeros_const1;
+ NodeDef* zeros_shape1;
NodeDef* zeros_node1;
+ NodeDef* zeros_const2;
+ NodeDef* zeros_shape2;
NodeDef* zeros_node2;
NodeDef* assign_node1;
NodeDef* assign_node2;
@@ -261,8 +274,13 @@ class SparsifyGatherTest : public ::testing::Test {
SetNodeTensorAttr<float>("value", weights, w_node2);
} else {
w_node1 = CreateNode("w1/part_1", "VariableV2", {}, &graph_def);
- zeros_node1 =
- CreateNode("w1/part_1/Initializer/zeros", "Const", {}, &graph_def);
+
+ zeros_shape1 = CreateNode("w1/part_1/Initializer/zeros/shape_as_tensor",
+ "Const", {}, &graph_def);
+ zeros_const1 = CreateNode("w1/part_1/Initializer/zeros/Const", "Const",
+ {}, &graph_def);
+ zeros_node1 = CreateNode("w1/part_1/Initializer/zeros", "Fill",
+ {zeros_shape1, zeros_const1}, &graph_def);
assign_node1 = CreateNode("w1/part_1/Assign", "Assign",
{w_node1, zeros_node1}, &graph_def);
@@ -285,8 +303,12 @@ class SparsifyGatherTest : public ::testing::Test {
CreateNode("save/Assign", "Assign", {w_node1, restore_node1}, &graph_def);
w_node2 = CreateNode("w2/part_1", "VariableV2", {}, &graph_def);
- zeros_node2 =
- CreateNode("w2/part_1/Initializer/zeros", "Const", {}, &graph_def);
+ zeros_shape2 = CreateNode("w2/part_1/Initializer/zeros/shape_as_tensor",
+ "Const", {}, &graph_def);
+ zeros_const2 = CreateNode("w2/part_1/Initializer/zeros/Const", "Const",
+ {}, &graph_def);
+ zeros_node2 = CreateNode("w2/part_1/Initializer/zeros", "Fill",
+ {zeros_shape2, zeros_const2}, &graph_def);
assign_node2 = CreateNode("w2/part_1/Assign", "Assign",
{w_node2, zeros_node2}, &graph_def);
@@ -350,8 +372,14 @@ class SparsifyGatherTest : public ::testing::Test {
MapNamesToNodes(result, &node_lookup);
// Check nodes.
+ EXPECT_EQ(0,
+ node_lookup.count("w1/part_1/Initializer/zeros/shape_as_tensor"));
+ EXPECT_EQ(0, node_lookup.count("w1/part_1/Initializer/zeros/Const"));
EXPECT_EQ(0, node_lookup.count("w1/part_1/Initializer/zeros"));
EXPECT_EQ(0, node_lookup.count("w1/part_1/Assign"));
+ EXPECT_EQ(0,
+ node_lookup.count("w2/part_1/Initializer/zeros/shape_as_tensor"));
+ EXPECT_EQ(0, node_lookup.count("w2/part_1/Initializer/zeros/Const"));
EXPECT_EQ(0, node_lookup.count("w2/part_1/Initializer/zeros"));
EXPECT_EQ(0, node_lookup.count("w2/part_1/Assign"));
EXPECT_EQ(1, node_lookup.count("ids"));