aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/graph_transforms
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-10-12 22:21:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-12 22:25:30 -0700
commita06b378194780c30ee695e9fe9a5b77aaf8bf1f4 (patch)
tree2bf138b84c9fed6a8dd9dafce21e1840ed68fdb4 /tensorflow/tools/graph_transforms
parent99dc61dbe520b43fcc1919124d2281d3c4fdfa85 (diff)
Add "clear_output_shapes" option to FoldConstants transformer in
tools/graph_transforms. By setting this option to false, the transformer will not strip off the shape information stored as attributes. PiperOrigin-RevId: 172057283
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--tensorflow/tools/graph_transforms/README.md7
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_lib.cc108
-rw-r--r--tensorflow/tools/graph_transforms/fold_constants_test.cc48
3 files changed, 126 insertions, 37 deletions
diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md
index 00297f07b7..c7f7eca257 100644
--- a/tensorflow/tools/graph_transforms/README.md
+++ b/tensorflow/tools/graph_transforms/README.md
@@ -385,7 +385,12 @@ input is collapsed down into a simple constant.
### fold_constants
-Args: None \
+Args:
+
+* clear_output_shapes: Clears tensor shape information saved as attributes.
+ Some older graphs contain out-of-date information and may cause import
+ errors. Defaults to true.
+
Prerequisites: None
Looks for any sub-graphs within the model that always evaluate to constant
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
index 0f5bc2bcdd..30290c7a16 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
+#include <algorithm>
+#include <iterator>
+
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -194,56 +197,99 @@ Status ShapeForNode(const TransformFuncContext& context,
Status FoldConstants(const GraphDef& input_graph_def,
const TransformFuncContext& context,
GraphDef* output_graph_def) {
- // Some older GraphDefs have saved _output_shapes attributes which are out of
- // date and cause import errors, so clean them up first.
- GraphDef cleaned_graph_def;
- RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_graph_def);
-
- // Set specified shapes.
- for (NodeDef& node : *cleaned_graph_def.mutable_node()) {
- TensorShape shape;
- bool has_shape_specified;
- TF_RETURN_IF_ERROR(
- ShapeForNode(context, node.name(), &shape, &has_shape_specified));
- if (has_shape_specified) {
- SetNodeAttr("shape", shape, &node);
- }
- }
-
Graph input_graph(OpRegistry::Global());
+ TF_RETURN_IF_ERROR(input_graph.AddFunctionLibrary(input_graph_def.library()));
+
ShapeRefiner shape_refiner(input_graph.versions(), input_graph.op_registry());
- shape_refiner.set_require_shape_inference_fns(true);
+ shape_refiner.set_require_shape_inference_fns(false);
shape_refiner.set_disable_constant_propagation(false);
- ImportGraphDefOptions import_opts;
- TF_RETURN_IF_ERROR(ImportGraphDef(import_opts, cleaned_graph_def,
- &input_graph, &shape_refiner));
- DeviceAttributes device_attributes;
- subgraph::RewriteGraphMetadata metadata;
- TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
- &input_graph, context.input_names, context.output_names, {},
- device_attributes, false /* use_function_convention */, &metadata));
+ shape_refiner.set_function_library_for_shape_inference(
+ &input_graph.flib_def());
- ConstantFoldingOptions cf_opts;
+ bool clear_output_shapes;
+ TF_RETURN_IF_ERROR(context.GetOneBoolParameter("clear_output_shapes", true,
+ &clear_output_shapes));
+ if (clear_output_shapes) {
+ // Some older GraphDefs have saved _output_shapes attributes which are out
+ // of date and cause import errors, so clean them up first.
+ GraphDef cleaned_graph_def;
+ RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_graph_def);
+
+ // Set specified shapes.
+ for (NodeDef& node : *cleaned_graph_def.mutable_node()) {
+ TensorShape shape;
+ bool has_shape_specified;
+ TF_RETURN_IF_ERROR(
+ ShapeForNode(context, node.name(), &shape, &has_shape_specified));
+ if (has_shape_specified) {
+ SetNodeAttr("shape", shape, &node);
+ }
+ }
+
+ TF_RETURN_IF_ERROR(
+ ImportGraphDef({}, cleaned_graph_def, &input_graph, &shape_refiner));
+ } else {
+ TF_RETURN_IF_ERROR(
+ ImportGraphDef({}, input_graph_def, &input_graph, &shape_refiner));
+ }
+
+ // Sorted array of input names as lookup table.
+ std::vector<TensorId> input_names;
+ input_names.reserve(context.input_names.size());
+ std::transform(context.input_names.begin(), context.input_names.end(),
+ std::back_inserter(input_names),
+ [](const string& name) { return ParseTensorName(name); });
+
+ const auto compare = [](TensorId lhs, TensorId rhs) {
+ return lhs.first < rhs.first;
+ };
+
+ std::sort(input_names.begin(), input_names.end(), compare);
// Set statically inferred shapes.
std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
for (const Node* const node : input_graph.nodes()) {
auto ctx = shape_refiner.GetContext(node);
- if (ctx == nullptr) continue;
+ if (ctx == nullptr) {
+ continue;
+ }
- std::vector<PartialTensorShape>* partial_shapes = &shape_map[node->name()];
+ std::vector<PartialTensorShape>& partial_shapes = shape_map[node->name()];
if (ctx->num_outputs() <= 0) continue;
- partial_shapes->resize(ctx->num_outputs());
+ partial_shapes.resize(ctx->num_outputs());
// Check all outputs.
for (const Edge* out_edge : node->out_edges()) {
if (out_edge->IsControlEdge()) continue;
const int output_idx = out_edge->src_output();
- TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(
- ctx->output(output_idx), ctx, &(*partial_shapes)[output_idx]));
+ TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(ctx->output(output_idx), ctx,
+ &partial_shapes[output_idx]));
+ }
+
+ // RewriteGraphForExecution() will add a Recv node for each input. Shape
+ // refiner does not include shape information of these Recv nodes. Therefore
+ // we add entries for Recv nodes here.
+ const auto pair = std::equal_range(input_names.begin(), input_names.end(),
+ TensorId{node->name(), 0}, compare);
+ for (auto it = pair.first; it != pair.second; ++it) {
+ const string recv_name =
+ strings::StrCat("_recv_", it->first, "_", it->second);
+ auto& recv_partial_shapes = shape_map[recv_name];
+ // For whatever reason (for example, name collision) if the map entry was
+ // already there, then do nothing.
+ if (recv_partial_shapes.empty()) {
+ recv_partial_shapes.push_back(partial_shapes[it->second]);
+ }
}
}
+
+ subgraph::RewriteGraphMetadata unused_metadata;
+ TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
+ &input_graph, context.input_names, context.output_names, {}, {},
+ false /* use_function_convention */, &unused_metadata));
+
+ ConstantFoldingOptions cf_opts;
cf_opts.shape_map = &shape_map;
// Exclude specified nodes from constant folding.
diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc
index d4100a652f..fd4188a6a4 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -71,7 +73,7 @@ class ConstantFoldingTest : public ::testing::Test {
test::FillIota<float>(&placeholder_tensor, 1.0f);
TestConstantFolding(graph_def,
{{"placeholder_expect_remains", placeholder_tensor}},
- {}, {"output_expect_remains"});
+ {}, {"output_expect_remains"}, {});
}
void TestOpExclusionAdd() {
@@ -105,7 +107,7 @@ class ConstantFoldingTest : public ::testing::Test {
test::FillIota<float>(&placeholder_tensor, 1.0f);
TestConstantFolding(graph_def,
{{"placeholder_expect_remains", placeholder_tensor}},
- {"Add"}, {"output_expect_remains"});
+ {"Add"}, {"output_expect_remains"}, {});
}
void TestShapePropagation() {
@@ -129,13 +131,46 @@ class ConstantFoldingTest : public ::testing::Test {
test::FillIota<float>(&placeholder_tensor, 1.0);
TestConstantFolding(graph_def,
{{"placeholder_expect_remains", placeholder_tensor}},
- {}, {"output_expect_remains"});
+ {}, {"output_expect_remains"}, {});
+ }
+
+ void TestPreserveOutputShapes() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ tensorflow::AttrValue shape_attr;
+ auto* shape_proto = shape_attr.mutable_list()->add_shape();
+ shape_proto->add_dim()->set_size(1);
+ shape_proto->add_dim()->set_size(1);
+ shape_proto->add_dim()->set_size(3);
+
+ Output placeholder =
+ Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
+ placeholder.node()->AddAttr("_output_shapes", shape_attr);
+
+ Output shape = Shape(root.WithOpName("shape_expect_removed"), placeholder);
+ Output cast = Cast(root.WithOpName("cast_expect_removed"), shape, DT_FLOAT);
+ Output mul =
+ Mul(root.WithOpName("output_expect_remains"), cast, placeholder);
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+ Tensor placeholder_tensor(DT_FLOAT, TensorShape({1, 1, 3}));
+ test::FillIota<float>(&placeholder_tensor, 1.0);
+
+ graph_transforms::TransformFuncContext context;
+ context.params["clear_output_shapes"] = {"false"};
+ TestConstantFolding(graph_def,
+ {{"placeholder_expect_remains", placeholder_tensor}},
+ {}, {"output_expect_remains"}, context);
}
void TestConstantFolding(const GraphDef& graph_def,
std::vector<std::pair<string, Tensor> > inputs,
std::vector<string> excluded_ops,
- const std::vector<string>& outputs) {
+ const std::vector<string>& outputs,
+ graph_transforms::TransformFuncContext context) {
std::unique_ptr<tensorflow::Session> unfolded_session(
tensorflow::NewSession(tensorflow::SessionOptions()));
TF_ASSERT_OK(unfolded_session->Create(graph_def));
@@ -143,7 +178,6 @@ class ConstantFoldingTest : public ::testing::Test {
TF_ASSERT_OK(unfolded_session->Run(inputs, outputs, {}, &unfolded_tensors));
GraphDef folded_graph_def;
- graph_transforms::TransformFuncContext context;
for (const std::pair<string, Tensor>& input : inputs) {
context.input_names.push_back(input.first);
}
@@ -269,6 +303,10 @@ TEST_F(ConstantFoldingTest, TestOpExclusionAdd) { TestOpExclusionAdd(); }
TEST_F(ConstantFoldingTest, TestShapePropagation) { TestShapePropagation(); }
+TEST_F(ConstantFoldingTest, TestPreserveOutputShapes) {
+ TestPreserveOutputShapes();
+}
+
TEST_F(ConstantFoldingTest, TestReplaceSendRecvs) { TestReplaceSendRecvs(); }
TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); }