diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-12 22:21:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-12 22:25:30 -0700 |
commit | a06b378194780c30ee695e9fe9a5b77aaf8bf1f4 (patch) | |
tree | 2bf138b84c9fed6a8dd9dafce21e1840ed68fdb4 /tensorflow/tools/graph_transforms | |
parent | 99dc61dbe520b43fcc1919124d2281d3c4fdfa85 (diff) |
Add "clear_output_shapes" option to FoldConstants transformer in
tools/graph_transforms.
By setting this option to false, the transformer will not strip off the shape
information stored as attributes.
PiperOrigin-RevId: 172057283
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r-- | tensorflow/tools/graph_transforms/README.md | 7 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_constants_lib.cc | 108 | ||||
-rw-r--r-- | tensorflow/tools/graph_transforms/fold_constants_test.cc | 48 |
3 files changed, 126 insertions, 37 deletions
diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md index 00297f07b7..c7f7eca257 100644 --- a/tensorflow/tools/graph_transforms/README.md +++ b/tensorflow/tools/graph_transforms/README.md @@ -385,7 +385,12 @@ input is collapsed down into a simple constant. ### fold_constants -Args: None \ +Args: + +* clear_output_shapes: Clears tensor shape information saved as attributes. + Some older graphs containes out-of-date information and may cause import + errors. Defaults to true. + Prerequisites: None Looks for any sub-graphs within the model that always evaluate to constant diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc index 0f5bc2bcdd..30290c7a16 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/tools/graph_transforms/fold_constants_lib.h" +#include <algorithm> +#include <iterator> + #include "tensorflow/core/common_runtime/constant_folding.h" #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -194,56 +197,99 @@ Status ShapeForNode(const TransformFuncContext& context, Status FoldConstants(const GraphDef& input_graph_def, const TransformFuncContext& context, GraphDef* output_graph_def) { - // Some older GraphDefs have saved _output_shapes attributes which are out of - // date and cause import errors, so clean them up first. - GraphDef cleaned_graph_def; - RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_graph_def); - - // Set specified shapes. - for (NodeDef& node : *cleaned_graph_def.mutable_node()) { - TensorShape shape; - bool has_shape_specified; - TF_RETURN_IF_ERROR( - ShapeForNode(context, node.name(), &shape, &has_shape_specified)); - if (has_shape_specified) { - SetNodeAttr("shape", shape, &node); - } - } - Graph input_graph(OpRegistry::Global()); + TF_RETURN_IF_ERROR(input_graph.AddFunctionLibrary(input_graph_def.library())); + ShapeRefiner shape_refiner(input_graph.versions(), input_graph.op_registry()); - shape_refiner.set_require_shape_inference_fns(true); + shape_refiner.set_require_shape_inference_fns(false); shape_refiner.set_disable_constant_propagation(false); - ImportGraphDefOptions import_opts; - TF_RETURN_IF_ERROR(ImportGraphDef(import_opts, cleaned_graph_def, - &input_graph, &shape_refiner)); - DeviceAttributes device_attributes; - subgraph::RewriteGraphMetadata metadata; - TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( - &input_graph, context.input_names, context.output_names, {}, - device_attributes, false /* use_function_convention */, &metadata)); + shape_refiner.set_function_library_for_shape_inference( + &input_graph.flib_def()); - ConstantFoldingOptions cf_opts; + bool clear_output_shapes; + TF_RETURN_IF_ERROR(context.GetOneBoolParameter("clear_output_shapes", true, + &clear_output_shapes)); + if (clear_output_shapes) { + // Some older GraphDefs have saved _output_shapes attributes which are out + // of date and cause import errors, so clean them up first. + GraphDef cleaned_graph_def; + RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_graph_def); + + // Set specified shapes. + for (NodeDef& node : *cleaned_graph_def.mutable_node()) { + TensorShape shape; + bool has_shape_specified; + TF_RETURN_IF_ERROR( + ShapeForNode(context, node.name(), &shape, &has_shape_specified)); + if (has_shape_specified) { + SetNodeAttr("shape", shape, &node); + } + } + + TF_RETURN_IF_ERROR( + ImportGraphDef({}, cleaned_graph_def, &input_graph, &shape_refiner)); + } else { + TF_RETURN_IF_ERROR( + ImportGraphDef({}, input_graph_def, &input_graph, &shape_refiner)); + } + + // Sorted array of input names as lookup table. + std::vector<TensorId> input_names; + input_names.reserve(context.input_names.size()); + std::transform(context.input_names.begin(), context.input_names.end(), + std::back_inserter(input_names), + [](const string& name) { return ParseTensorName(name); }); + + const auto compare = [](TensorId lhs, TensorId rhs) { + return lhs.first < rhs.first; + }; + + std::sort(input_names.begin(), input_names.end(), compare); // Set statically inferred shapes. std::unordered_map<string, std::vector<PartialTensorShape>> shape_map; for (const Node* const node : input_graph.nodes()) { auto ctx = shape_refiner.GetContext(node); - if (ctx == nullptr) continue; + if (ctx == nullptr) { + continue; + } - std::vector<PartialTensorShape>* partial_shapes = &shape_map[node->name()]; + std::vector<PartialTensorShape>& partial_shapes = shape_map[node->name()]; if (ctx->num_outputs() <= 0) continue; - partial_shapes->resize(ctx->num_outputs()); + partial_shapes.resize(ctx->num_outputs()); // Check all outputs. for (const Edge* out_edge : node->out_edges()) { if (out_edge->IsControlEdge()) continue; const int output_idx = out_edge->src_output(); - TF_RETURN_IF_ERROR(ShapeHandleToTensorShape( - ctx->output(output_idx), ctx, &(*partial_shapes)[output_idx])); + TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(ctx->output(output_idx), ctx, + &partial_shapes[output_idx])); + } + + // RewriteGraphForExecution() will add a Recv node for each input. Shape + // refiner does not include shape information of these Recv nodes. Therefore + // we add entries for Recv nodes here. + const auto pair = std::equal_range(input_names.begin(), input_names.end(), + TensorId{node->name(), 0}, compare); + for (auto it = pair.first; it != pair.second; ++it) { + const string recv_name = + strings::StrCat("_recv_", it->first, "_", it->second); + auto& recv_partial_shapes = shape_map[recv_name]; + // For whatever reason (for example, name collision) if the map entry was + // already there, then do nothing. + if (recv_partial_shapes.empty()) { + recv_partial_shapes.push_back(partial_shapes[it->second]); + } } } + + subgraph::RewriteGraphMetadata unused_metadata; + TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( + &input_graph, context.input_names, context.output_names, {}, {}, + false /* use_function_convention */, &unused_metadata)); + + ConstantFoldingOptions cf_opts; cf_opts.shape_map = &shape_map; // Exclude specified nodes from constant folding. diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc index d4100a652f..fd4188a6a4 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_test.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc @@ -20,6 +20,8 @@ limitations under the License. #include "tensorflow/cc/ops/nn_ops.h" #include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -71,7 +73,7 @@ class ConstantFoldingTest : public ::testing::Test { test::FillIota<float>(&placeholder_tensor, 1.0f); TestConstantFolding(graph_def, {{"placeholder_expect_remains", placeholder_tensor}}, - {}, {"output_expect_remains"}); + {}, {"output_expect_remains"}, {}); } void TestOpExclusionAdd() { @@ -105,7 +107,7 @@ class ConstantFoldingTest : public ::testing::Test { test::FillIota<float>(&placeholder_tensor, 1.0f); TestConstantFolding(graph_def, {{"placeholder_expect_remains", placeholder_tensor}}, - {"Add"}, {"output_expect_remains"}); + {"Add"}, {"output_expect_remains"}, {}); } void TestShapePropagation() { @@ -129,13 +131,46 @@ class ConstantFoldingTest : public ::testing::Test { test::FillIota<float>(&placeholder_tensor, 1.0); TestConstantFolding(graph_def, {{"placeholder_expect_remains", placeholder_tensor}}, - {}, {"output_expect_remains"}); + {}, {"output_expect_remains"}, {}); + } + + void TestPreserveOutputShapes() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + tensorflow::AttrValue shape_attr; + auto* shape_proto = shape_attr.mutable_list()->add_shape(); + shape_proto->add_dim()->set_size(1); + shape_proto->add_dim()->set_size(1); + shape_proto->add_dim()->set_size(3); + + Output placeholder = + Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT); + placeholder.node()->AddAttr("_output_shapes", shape_attr); + + Output shape = Shape(root.WithOpName("shape_expect_removed"), placeholder); + Output cast = Cast(root.WithOpName("cast_expect_removed"), shape, DT_FLOAT); + Output mul = + Mul(root.WithOpName("output_expect_remains"), cast, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + Tensor placeholder_tensor(DT_FLOAT, TensorShape({1, 1, 3})); + test::FillIota<float>(&placeholder_tensor, 1.0); + + graph_transforms::TransformFuncContext context; + context.params["clear_output_shapes"] = {"false"}; + TestConstantFolding(graph_def, + {{"placeholder_expect_remains", placeholder_tensor}}, + {}, {"output_expect_remains"}, context); } void TestConstantFolding(const GraphDef& graph_def, std::vector<std::pair<string, Tensor> > inputs, std::vector<string> excluded_ops, - const std::vector<string>& outputs) { + const std::vector<string>& outputs, + graph_transforms::TransformFuncContext context) { std::unique_ptr<tensorflow::Session> unfolded_session( tensorflow::NewSession(tensorflow::SessionOptions())); TF_ASSERT_OK(unfolded_session->Create(graph_def)); @@ -143,7 +178,6 @@ class ConstantFoldingTest : public ::testing::Test { TF_ASSERT_OK(unfolded_session->Run(inputs, outputs, {}, &unfolded_tensors)); GraphDef folded_graph_def; - graph_transforms::TransformFuncContext context; for (const std::pair<string, Tensor>& input : inputs) { context.input_names.push_back(input.first); } @@ -269,6 +303,10 @@ TEST_F(ConstantFoldingTest, TestOpExclusionAdd) { TestOpExclusionAdd(); } TEST_F(ConstantFoldingTest, TestShapePropagation) { TestShapePropagation(); } +TEST_F(ConstantFoldingTest, TestPreserveOutputShapes) { + TestPreserveOutputShapes(); +} + TEST_F(ConstantFoldingTest, TestReplaceSendRecvs) { TestReplaceSendRecvs(); } TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); } |