diff options
author | 2016-11-08 17:47:43 -0800 | |
---|---|---|
committer | 2016-11-08 18:03:59 -0800 | |
commit | 7ea2f7d2689d9686c93650bf5bcc4c8ba459377d (patch) | |
tree | 6abadcdfd651c5f0134ac0f37bbc7901fa18ee16 /tensorflow/tools/graph_transforms | |
parent | cadb43c37a1806dd617233fab40330927289c89a (diff) |
Refactor fold_constants to make reuse easier.
Change: 138588267
Diffstat (limited to 'tensorflow/tools/graph_transforms')
5 files changed, 61 insertions, 4 deletions
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index d0aad58222..02500cccdc 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -92,11 +92,12 @@ tf_cc_test( ], ) -cc_binary( - name = "fold_constants_tool", +# This library includes a main function, to make it easy to create other +# versions of the tool linked against different operator libs. +cc_library( + name = "fold_constants_main_lib", srcs = ["fold_constants_tool.cc"], copts = tf_copts(), - linkstatic = 1, visibility = ["//visibility:public"], deps = [ ":fold_constants_lib", @@ -104,3 +105,13 @@ cc_binary( "//tensorflow/core:lib", ], ) + +cc_binary( + name = "fold_constants_tool", + copts = tf_copts(), + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + ":fold_constants_main_lib", + ], +) diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc index 01e758da81..373d237f08 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc @@ -19,7 +19,9 @@ limitations under the License. #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/subgraph.h" +#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/public/session.h" +#include "tensorflow/core/util/command_line_flags.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" namespace tensorflow { @@ -136,10 +138,14 @@ Status FoldConstants(const GraphDef& input_graph_def, const std::vector<string>& inputs, const std::vector<string>& outputs, GraphDef* output_graph_def) { + // Some older GraphDefs have saved _output_shapes attributes which are out of + // date and cause import errors, so clean them up first. + GraphDef cleaned_graph_def; + RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_graph_def); Graph input_graph(OpRegistry::Global()); ImportGraphDefOptions import_opts; TF_RETURN_IF_ERROR( - ImportGraphDef(import_opts, input_graph_def, &input_graph, nullptr)); + ImportGraphDef(import_opts, cleaned_graph_def, &input_graph, nullptr)); DeviceAttributes device_attributes; TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( &input_graph, inputs, outputs, {}, device_attributes)); diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index b3e0093afd..72664eee9b 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -64,5 +64,18 @@ void FilterGraphDef(const GraphDef& input_graph_def, } } +void RemoveAttributes(const GraphDef& input_graph_def, + const std::vector<string>& attributes, + GraphDef* output_graph_def) { + output_graph_def->mutable_node()->Clear(); + for (const NodeDef& node : input_graph_def.node()) { + NodeDef* new_node = output_graph_def->mutable_node()->Add(); + new_node->CopyFrom(node); + for (const string& attribute : attributes) { + new_node->mutable_attr()->erase(attribute); + } + } +} + } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h index 8f6433fa34..7fb885b1ac 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.h +++ b/tensorflow/tools/graph_transforms/transform_utils.h @@ -45,6 +45,12 @@ void FilterGraphDef(const GraphDef& input_graph_def, std::function<bool(const NodeDef&)> selector, GraphDef* output_graph_def); +// Creates a copy of the input graph, with all occurences of the attributes with +// the names in the argument removed from the node defs. +void RemoveAttributes(const GraphDef& input_graph_def, + const std::vector<string>& attributes, + GraphDef* output_graph_def); + } // namespace graph_transforms } // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/transform_utils_test.cc b/tensorflow/tools/graph_transforms/transform_utils_test.cc index 25dfb015d6..1c1f4d97ed 100644 --- a/tensorflow/tools/graph_transforms/transform_utils_test.cc +++ b/tensorflow/tools/graph_transforms/transform_utils_test.cc @@ -141,6 +141,25 @@ class TransformUtilsTest : public ::testing::Test { EXPECT_EQ(1, node_map.count("output")); EXPECT_EQ(0, node_map.count("remove_me")); } + + void TestRemoveAttributes() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Output placeholder = Placeholder(root.WithOpName("placeholder"), DT_FLOAT); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + GraphDef result_graph_def; + RemoveAttributes(graph_def, {"dtype"}, &result_graph_def); + + std::map<string, const NodeDef*> node_map; + MapNamesToNodes(result_graph_def, &node_map); + const NodeDef* removed_placeholder = node_map["placeholder"]; + EXPECT_EQ(nullptr, + tensorflow::AttrSlice(*removed_placeholder).Find("dtype")); + } }; TEST_F(TransformUtilsTest, TestMapNamesToNodes) { TestMapNamesToNodes(); } @@ -153,5 +172,7 @@ TEST_F(TransformUtilsTest, TestNodeNameFromInput) { TestNodeNameFromInput(); } TEST_F(TransformUtilsTest, TestFilterGraphDef) { TestFilterGraphDef(); } +TEST_F(TransformUtilsTest, TestRemoveAttributes) { TestRemoveAttributes(); } + } // namespace graph_transforms } // namespace tensorflow |