aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
authorGravatar Pete Warden <petewarden@google.com>2016-11-08 17:47:43 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-08 18:03:59 -0800
commit7ea2f7d2689d9686c93650bf5bcc4c8ba459377d (patch)
tree6abadcdfd651c5f0134ac0f37bbc7901fa18ee16 /tensorflow/tools/graph_transforms
parentcadb43c37a1806dd617233fab40330927289c89a (diff)
Refactor fold_constants to make reuse easier.
Change: 138588267
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--tensorflow/tools/graph_transforms/BUILD17
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_lib.cc8
-rw-r--r--tensorflow/tools/graph_transforms/transform_utils.cc13
-rw-r--r--tensorflow/tools/graph_transforms/transform_utils.h6
-rw-r--r--tensorflow/tools/graph_transforms/transform_utils_test.cc21
5 files changed, 61 insertions, 4 deletions
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index d0aad58222..02500cccdc 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -92,11 +92,12 @@ tf_cc_test(
],
)
-cc_binary(
- name = "fold_constants_tool",
+# This library includes a main function, to make it easy to create other
+# versions of the tool linked against different operator libs.
+cc_library(
+ name = "fold_constants_main_lib",
srcs = ["fold_constants_tool.cc"],
copts = tf_copts(),
- linkstatic = 1,
visibility = ["//visibility:public"],
deps = [
":fold_constants_lib",
@@ -104,3 +105,13 @@ cc_binary(
"//tensorflow/core:lib",
],
)
+
+cc_binary(
+ name = "fold_constants_tool",
+ copts = tf_copts(),
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+ deps = [
+ ":fold_constants_main_lib",
+ ],
+)
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
index 01e758da81..373d237f08 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
@@ -19,7 +19,9 @@ limitations under the License.
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
+#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
@@ -136,10 +138,14 @@ Status FoldConstants(const GraphDef& input_graph_def,
const std::vector<string>& inputs,
const std::vector<string>& outputs,
GraphDef* output_graph_def) {
+ // Some older GraphDefs have saved _output_shapes attributes which are out of
+ // date and cause import errors, so clean them up first.
+ GraphDef cleaned_graph_def;
+ RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_graph_def);
Graph input_graph(OpRegistry::Global());
ImportGraphDefOptions import_opts;
TF_RETURN_IF_ERROR(
- ImportGraphDef(import_opts, input_graph_def, &input_graph, nullptr));
+ ImportGraphDef(import_opts, cleaned_graph_def, &input_graph, nullptr));
DeviceAttributes device_attributes;
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
&input_graph, inputs, outputs, {}, device_attributes));
diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc
index b3e0093afd..72664eee9b 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils.cc
@@ -64,5 +64,18 @@ void FilterGraphDef(const GraphDef& input_graph_def,
}
}
+void RemoveAttributes(const GraphDef& input_graph_def,
+ const std::vector<string>& attributes,
+ GraphDef* output_graph_def) {
+ output_graph_def->mutable_node()->Clear();
+ for (const NodeDef& node : input_graph_def.node()) {
+ NodeDef* new_node = output_graph_def->mutable_node()->Add();
+ new_node->CopyFrom(node);
+ for (const string& attribute : attributes) {
+ new_node->mutable_attr()->erase(attribute);
+ }
+ }
+}
+
} // namespace graph_transforms
} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h
index 8f6433fa34..7fb885b1ac 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.h
+++ b/tensorflow/tools/graph_transforms/transform_utils.h
@@ -45,6 +45,12 @@ void FilterGraphDef(const GraphDef& input_graph_def,
std::function<bool(const NodeDef&)> selector,
GraphDef* output_graph_def);
+// Creates a copy of the input graph, with all occurences of the attributes with
+// the names in the argument removed from the node defs.
+void RemoveAttributes(const GraphDef& input_graph_def,
+ const std::vector<string>& attributes,
+ GraphDef* output_graph_def);
+
} // namespace graph_transforms
} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc
index 25dfb015d6..1c1f4d97ed 100644
--- a/tensorflow/tools/graph_transforms/transform_utils_test.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc
@@ -141,6 +141,25 @@ class TransformUtilsTest : public ::testing::Test {
EXPECT_EQ(1, node_map.count("output"));
EXPECT_EQ(0, node_map.count("remove_me"));
}
+
+ void TestRemoveAttributes() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT);
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+ GraphDef result_graph_def;
+ RemoveAttributes(graph_def, {"dtype"}, &result_graph_def);
+
+ std::map<string, const NodeDef*> node_map;
+ MapNamesToNodes(result_graph_def, &node_map);
+ const NodeDef* removed_placeholder = node_map["placeholder"];
+ EXPECT_EQ(nullptr,
+ tensorflow::AttrSlice(*removed_placeholder).Find("dtype"));
+ }
};
TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); }
@@ -153,5 +172,7 @@ TEST_F(TransformUtilsTest, TestNodeNameFromInput) { TestNodeNameFromInput(); }
TEST_F(TransformUtilsTest, TestFilterGraphDef) { TestFilterGraphDef(); }
+TEST_F(TransformUtilsTest, TestRemoveAttributes) { TestRemoveAttributes(); }
+
} // namespace graph_transforms
} // namespace tensorflow