aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-31 09:53:17 -0800
committerGravatar Michael Case <mikecase@google.com>2018-01-31 17:08:10 -0800
commit16b9fc676be6f8aacf06977a7f9439a56ffccefa (patch)
treede2890c9997059bffd67caf6d24d9adfe31d0aff /tensorflow/tools/graph_transforms
parent76989a191815bdd96390626db154676ac42b890d (diff)
Extending sparsify_gather to remove variables from the tensorflow summaries.
PiperOrigin-RevId: 184004859
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--tensorflow/tools/graph_transforms/sparsify_gather.cc80
-rw-r--r--tensorflow/tools/graph_transforms/sparsify_gather_test.cc86
2 files changed, 131 insertions, 35 deletions
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather.cc b/tensorflow/tools/graph_transforms/sparsify_gather.cc
index 593c654f9f..9c583d83ca 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather.cc
@@ -181,6 +181,14 @@ Status ObtainVariableInfo(
return Status::OK();
}
+Status RemoveInputAtIndex(NodeDef* n, int index) {
+ for (int i = index; i < n->input_size() - 1; i++) {
+ n->mutable_input()->SwapElements(i, i + 1);
+ }
+ n->mutable_input()->RemoveLast();
+ return Status::OK();
+}
+
Status SparsifyGatherInternal(
const GraphDef& input_graph_def,
const std::unique_ptr<std::unordered_map<string, string> >&
@@ -301,13 +309,13 @@ Status SparsifyGatherInternal(
TF_RETURN_IF_ERROR(ReadTensorFromCheckpoint(
weights_node.name(), ckpt_reader,
(*shapes_and_slices)[weights_node.name()], &weight));
- // Add both both weight and identity node names.
- removed_node_names.push_back(weights_node.name());
- removed_node_names.push_back(match.inputs[0].node.name());
- for (auto input_node : match.inputs[0].node.input()) {
- auto parsed_input = StringReplace(input_node, "^", "", true);
- refs[parsed_input]--;
- }
+ }
+ // Add both both weight and identity node names.
+ removed_node_names.push_back(weights_node.name());
+ removed_node_names.push_back(match.inputs[0].node.name());
+ for (auto input_node : match.inputs[0].node.input()) {
+ auto parsed_input = StringReplace(input_node, "^", "", true);
+ refs[parsed_input]--;
}
Tensor indices_tensor;
Tensor values_tensor;
@@ -468,26 +476,49 @@ Status SparsifyGatherInternal(
continue;
}
int j = 0;
+ bool deleted_inputs = false;
while (j < replaced_graph_def.node(i).input_size()) {
if (replaced_graph_def.node(i).input(j) == name ||
replaced_graph_def.node(i).input(j) == ("^" + name)) {
- replaced_graph_def.mutable_node(i)->mutable_input()->SwapElements(
- j, replaced_graph_def.node(i).input_size() - 1);
- replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast();
+ TF_RETURN_IF_ERROR(
+ RemoveInputAtIndex(replaced_graph_def.mutable_node(i), j));
+ deleted_inputs = true;
continue;
}
j++;
}
- if (!replaced_graph_def.node(i).input_size()) {
- if ((refs.find(replaced_graph_def.node(i).name()) != refs.end()) &&
- (refs[replaced_graph_def.node(i).name()] == 0)) {
+ if (deleted_inputs) {
+ if (replaced_graph_def.node(i).op() == "ConcatV2") {
+ if (replaced_graph_def.node(i).input_size() > 2) {
+ SetNodeAttr("N", replaced_graph_def.node(i).input_size() - 1,
+ replaced_graph_def.mutable_node(i));
+ } else if (replaced_graph_def.node(i).input_size() == 2) {
+ if (refs[replaced_graph_def.node(i).input(1)] != 1) {
+ return errors::Internal(
+ "Expect axis tensor of ConcatV2 node to only be referenced "
+ "once.");
+ }
+ refs[replaced_graph_def.node(i).input(1)] -= 1;
+ removed_node_names.push_back(replaced_graph_def.node(i).input(1));
+ replaced_graph_def.mutable_node(i)->mutable_input()->RemoveLast();
+ replaced_graph_def.mutable_node(i)->mutable_attr()->erase("N");
+ replaced_graph_def.mutable_node(i)->set_op("Identity");
+ } else {
+ return errors::Internal(
+ "ConcatV2 should have at least two elements");
+ }
+ }
+ if ((replaced_graph_def.node(i).op() == "Assign" ||
+ replaced_graph_def.node(i).op() == "Reshape" ||
+ replaced_graph_def.node(i).op() == "Equal" ||
+ replaced_graph_def.node(i).op() == "Mean" ||
+ replaced_graph_def.node(i).op() == "ScalarSummary") &&
+ replaced_graph_def.node(i).input_size() == 1) {
+ removed_node_names.push_back(replaced_graph_def.node(i).name());
+ }
+ if (!replaced_graph_def.node(i).input_size()) {
removed_node_names.push_back(replaced_graph_def.node(i).name());
}
- }
-
- if (replaced_graph_def.node(i).op() == "Assign" &&
- replaced_graph_def.node(i).input_size() == 1) {
- removed_node_names.push_back(replaced_graph_def.node(i).name());
}
i++;
}
@@ -528,17 +559,22 @@ Status SparsifyGather(const GraphDef& input_graph_def,
};
// clang-format on
+ GraphDef cleaned_input_graph_def;
+ RemoveAttributes(input_graph_def, {"_output_shapes"},
+ &cleaned_input_graph_def);
+
GraphDef temp_output;
std::unique_ptr<BundleReader> ckpt_reader;
TF_RETURN_IF_ERROR(InitializeCheckpointReader(context, &ckpt_reader));
std::unique_ptr<std::unordered_map<string, string> > shapes_and_slices;
- TF_RETURN_IF_ERROR(ObtainVariableInfo(input_graph_def, &shapes_and_slices));
+ TF_RETURN_IF_ERROR(
+ ObtainVariableInfo(cleaned_input_graph_def, &shapes_and_slices));
- TF_RETURN_IF_ERROR(SparsifyGatherInternal(input_graph_def, shapes_and_slices,
- context, gather_pattern,
- ckpt_reader, &temp_output));
+ TF_RETURN_IF_ERROR(SparsifyGatherInternal(
+ cleaned_input_graph_def, shapes_and_slices, context, gather_pattern,
+ ckpt_reader, &temp_output));
TF_RETURN_IF_ERROR(SparsifyGatherInternal(temp_output, shapes_and_slices,
context, gather_v2_pattern,
diff --git a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
index 6627df1331..203ed3e0f9 100644
--- a/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
+++ b/tensorflow/tools/graph_transforms/sparsify_gather_test.cc
@@ -71,7 +71,7 @@ class SparsifyGatherTest : public ::testing::Test {
}
void TestSinglePartition(bool gather_v2, bool include_shared_init,
- bool test_variable,
+ bool test_variable, bool test_kept_concat,
const string& shared_init_name = "group_deps") {
GraphDef graph_def;
@@ -139,6 +139,26 @@ class SparsifyGatherTest : public ::testing::Test {
}
}
+ NodeDef* concat_axis_node =
+ CreateNode("linear/concat/axis", "Const", {}, &graph_def);
+ NodeDef* concat_input_node =
+ CreateNode("concat/input/node", "Const", {}, &graph_def);
+ NodeDef* concat_node = nullptr;
+ if (!test_kept_concat) {
+ concat_node = CreateNode(
+ "concat/node", "ConcatV2",
+ {identity_node, concat_input_node, concat_axis_node}, &graph_def);
+ SetNodeAttr("N", 2, concat_node);
+ } else {
+ NodeDef* concat_input_node_2 =
+ CreateNode("concat/input/node_2", "Const", {}, &graph_def);
+ concat_node = CreateNode("concat/node", "ConcatV2",
+ {identity_node, concat_input_node,
+ concat_input_node_2, concat_axis_node},
+ &graph_def);
+ SetNodeAttr("N", 3, concat_node);
+ }
+
// Run the op.
GraphDef result;
TransformFuncContext context;
@@ -166,6 +186,23 @@ class SparsifyGatherTest : public ::testing::Test {
EXPECT_EQ(1, node_lookup.count("ids"));
EXPECT_EQ("Const", node_lookup.at("ids")->op());
+ EXPECT_EQ(1, node_lookup.count("concat/node"));
+
+ if (!test_kept_concat) {
+ EXPECT_EQ(0, node_lookup.count("linear/concat/axis"));
+ EXPECT_EQ("Identity", node_lookup.at("concat/node")->op());
+ EXPECT_EQ(1, node_lookup.at("concat/node")->input_size());
+ EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0));
+ } else {
+ EXPECT_EQ(1, node_lookup.count("linear/concat/axis"));
+ EXPECT_EQ("ConcatV2", node_lookup.at("concat/node")->op());
+ EXPECT_EQ(3, node_lookup.at("concat/node")->input_size());
+ EXPECT_EQ("concat/input/node", node_lookup.at("concat/node")->input(0));
+ EXPECT_EQ("concat/input/node_2", node_lookup.at("concat/node")->input(1));
+ EXPECT_EQ("linear/concat/axis", node_lookup.at("concat/node")->input(2));
+ EXPECT_EQ(2, node_lookup.at("concat/node")->attr().at("N").i());
+ }
+
EXPECT_EQ(1, node_lookup.count("w/part_1/indices"));
EXPECT_EQ("Const", node_lookup.at("w/part_1/indices")->op());
Tensor expected_indices_tensor(DT_INT64, TensorShape({3}));
@@ -344,6 +381,13 @@ class SparsifyGatherTest : public ::testing::Test {
MakeGather("gather1", gather_v2, identity_node1, input_node, &graph_def);
MakeGather("gather2", gather_v2, identity_node2, input_node, &graph_def);
+ NodeDef* concat_axis_node =
+ CreateNode("linear/concat/axis", "Const", {}, &graph_def);
+ NodeDef* concat_node = CreateNode(
+ "concat/node", "ConcatV2",
+ {identity_node1, identity_node2, concat_axis_node}, &graph_def);
+ SetNodeAttr("N", 2, concat_node);
+
// Shared init node
if (include_shared_init) {
if (!test_variable) {
@@ -515,6 +559,9 @@ class SparsifyGatherTest : public ::testing::Test {
node_lookup.at("gather2/LookupTableFind")->input(2));
EXPECT_EQ("gather2/LookupTableFind", node_lookup.at("gather2")->input(0));
+ EXPECT_EQ(0, node_lookup.count("linear/concat/axis"));
+ EXPECT_EQ(0, node_lookup.count("concat/node"));
+
// Check control deps.
EXPECT_EQ(2, node_lookup.at(shared_init_name)->input_size());
EXPECT_NE(std::find(node_lookup.at(shared_init_name)->input().begin(),
@@ -550,18 +597,31 @@ class SparsifyGatherTest : public ::testing::Test {
};
TEST_F(SparsifyGatherTest, TestSinglePartition) {
- TestSinglePartition(false, false, false);
- TestSinglePartition(false, true, false);
- TestSinglePartition(true, false, false);
- TestSinglePartition(true, true, false);
- TestSinglePartition(false, false, true);
- TestSinglePartition(false, true, true);
- TestSinglePartition(true, false, true);
- TestSinglePartition(true, true, true);
- TestSinglePartition(false, true, false, "shared_inits");
- TestSinglePartition(true, true, false, "shared_inits");
- TestSinglePartition(false, true, true, "shared_inits");
- TestSinglePartition(true, true, true, "shared_inits");
+ TestSinglePartition(false, false, false, false);
+ TestSinglePartition(false, true, false, false);
+ TestSinglePartition(true, false, false, false);
+ TestSinglePartition(true, true, false, false);
+ TestSinglePartition(false, false, true, false);
+ TestSinglePartition(false, true, true, false);
+ TestSinglePartition(true, false, true, false);
+ TestSinglePartition(true, true, true, false);
+ TestSinglePartition(false, true, false, false, "shared_inits");
+ TestSinglePartition(true, true, false, false, "shared_inits");
+ TestSinglePartition(false, true, true, false, "shared_inits");
+ TestSinglePartition(true, true, true, false, "shared_inits");
+
+ TestSinglePartition(false, false, false, true);
+ TestSinglePartition(false, true, false, true);
+ TestSinglePartition(true, false, false, true);
+ TestSinglePartition(true, true, false, true);
+ TestSinglePartition(false, false, true, true);
+ TestSinglePartition(false, true, true, true);
+ TestSinglePartition(true, false, true, true);
+ TestSinglePartition(true, true, true, true);
+ TestSinglePartition(false, true, false, true, "shared_inits");
+ TestSinglePartition(true, true, false, true, "shared_inits");
+ TestSinglePartition(false, true, true, true, "shared_inits");
+ TestSinglePartition(true, true, true, true, "shared_inits");
}
TEST_F(SparsifyGatherTest, TestMultiPartition) {