diff options
| author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-10-03 01:04:18 -0700 |
|---|---|---|
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-03 01:08:11 -0700 |
| commit | 0ea4331690c9f00abfbb634a91520042b7b84a20 (patch) | |
| tree | f63f834d30c059055d0b318123e162eb52445688 /tensorflow/tools/graph_transforms | |
| parent | 263d025fb6dee974eefb30a51372188fb856d6cc (diff) | |
Use shape information in constant propagation.
PiperOrigin-RevId: 170818644
Diffstat (limited to 'tensorflow/tools/graph_transforms')
5 files changed, 144 insertions, 25 deletions
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc index f97e485418..0f5bc2bcdd 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/tools/graph_transforms/fold_constants_lib.h" #include "tensorflow/core/common_runtime/constant_folding.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/subgraph.h" @@ -133,6 +134,61 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def, return Status::OK(); } +// Converts a shape inference handle to a PartialTensorShape. +Status ShapeHandleToTensorShape(const shape_inference::ShapeHandle& handle, + shape_inference::InferenceContext* context, + PartialTensorShape* shape) { + // The default is already unknown + if (!context->RankKnown(handle)) return Status::OK(); + + std::vector<int64> dims(context->Rank(handle)); + for (int32 i = 0; i < dims.size(); ++i) { + dims[i] = context->Value(context->Dim(handle, i)); + } + return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape); +} + +Status ShapeForNode(const TransformFuncContext& context, + const string& node_name, TensorShape* result, + bool* has_shape_specified) { + *has_shape_specified = false; + + // Check to see if we have been given a default for all placeholders. + if (context.params.count("type")) { + if (context.params.at("shape").size() != 1) { + return errors::InvalidArgument( + "You must pass no more than one default 'shape' to " + "fold_constants"); + } + const string& shape_string = context.params.at("shape")[0]; + TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result)); + *has_shape_specified = true; + } + + // See if there's a particular type specified for this placeholder. 
+ if (context.params.count("name") || context.params.count("type_for_name")) { + if (!context.params.count("name") || + !context.params.count("type_for_name") || + (context.params.at("type_for_name").size() != + context.params.at("name").size())) { + return errors::InvalidArgument( + "You must pass a 'shape_for_name' arg for every 'name', e.g. " + "fold_constants(name=foo, shape_for_name=\"2,2,1\", name=bar, " + "shape_for_name=\"1\""); + } + const int name_count = context.params.at("name").size(); + for (int i = 0; i < name_count; ++i) { + if (context.params.at("name")[i] == node_name) { + const string& shape_string = context.params.at("shape_for_name")[i]; + TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result)); + *has_shape_specified = true; + } + } + } + + return Status::OK(); +} + // Converts any sub-graphs that can be resolved into constant expressions into // single Const ops. Status FoldConstants(const GraphDef& input_graph_def, @@ -142,18 +198,55 @@ Status FoldConstants(const GraphDef& input_graph_def, // date and cause import errors, so clean them up first. GraphDef cleaned_graph_def; RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_graph_def); + + // Set specified shapes. 
+ for (NodeDef& node : *cleaned_graph_def.mutable_node()) { + TensorShape shape; + bool has_shape_specified; + TF_RETURN_IF_ERROR( + ShapeForNode(context, node.name(), &shape, &has_shape_specified)); + if (has_shape_specified) { + SetNodeAttr("shape", shape, &node); + } + } + Graph input_graph(OpRegistry::Global()); + ShapeRefiner shape_refiner(input_graph.versions(), input_graph.op_registry()); + shape_refiner.set_require_shape_inference_fns(true); + shape_refiner.set_disable_constant_propagation(false); ImportGraphDefOptions import_opts; - TF_RETURN_IF_ERROR( - ImportGraphDef(import_opts, cleaned_graph_def, &input_graph, nullptr)); + TF_RETURN_IF_ERROR(ImportGraphDef(import_opts, cleaned_graph_def, + &input_graph, &shape_refiner)); DeviceAttributes device_attributes; subgraph::RewriteGraphMetadata metadata; TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution( &input_graph, context.input_names, context.output_names, {}, device_attributes, false /* use_function_convention */, &metadata)); - bool was_mutated; - // Exclude specified nodes from constant folding. + ConstantFoldingOptions cf_opts; + + // Set statically inferred shapes. + std::unordered_map<string, std::vector<PartialTensorShape>> shape_map; + for (const Node* const node : input_graph.nodes()) { + auto ctx = shape_refiner.GetContext(node); + if (ctx == nullptr) continue; + + std::vector<PartialTensorShape>* partial_shapes = &shape_map[node->name()]; + if (ctx->num_outputs() <= 0) continue; + partial_shapes->resize(ctx->num_outputs()); + + // Check all outputs. + for (const Edge* out_edge : node->out_edges()) { + if (out_edge->IsControlEdge()) continue; + + const int output_idx = out_edge->src_output(); + TF_RETURN_IF_ERROR(ShapeHandleToTensorShape( + ctx->output(output_idx), ctx, &(*partial_shapes)[output_idx])); + } + } + cf_opts.shape_map = &shape_map; + + // Exclude specified nodes from constant folding. 
if (context.params.count("exclude_op") > 0) { const auto& excluded_nodes = context.params.at("exclude_op"); const std::set<string> excluded_nodes_set(excluded_nodes.begin(), @@ -163,6 +256,9 @@ Status FoldConstants(const GraphDef& input_graph_def, excluded_nodes_set.end(); }; } + + // Constant folding. + bool was_mutated; TF_RETURN_IF_ERROR(ConstantFold(cf_opts, nullptr, Env::Default(), nullptr, &input_graph, &was_mutated)); GraphDef folded_graph_def; diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc index 14e2c01c7c..d4100a652f 100644 --- a/tensorflow/tools/graph_transforms/fold_constants_test.cc +++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc @@ -108,6 +108,30 @@ class ConstantFoldingTest : public ::testing::Test { {"Add"}, {"output_expect_remains"}); } + void TestShapePropagation() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Output placeholder = + Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT); + Output a_const = + Const(root.WithOpName("a_expect_removed"), + Input::Initializer({1, 1, 1}, TensorShape({1, 1, 3}))); + Output shape = Shape(root.WithOpName("shape_expect_removed"), a_const); + Output cast = Cast(root.WithOpName("cast_expect_removed"), shape, DT_FLOAT); + Output mul = + Mul(root.WithOpName("output_expect_remains"), cast, placeholder); + + GraphDef graph_def; + TF_ASSERT_OK(root.ToGraphDef(&graph_def)); + + Tensor placeholder_tensor(DT_FLOAT, TensorShape({3})); + test::FillIota<float>(&placeholder_tensor, 1.0); + TestConstantFolding(graph_def, + {{"placeholder_expect_remains", placeholder_tensor}}, + {}, {"output_expect_remains"}); + } + void TestConstantFolding(const GraphDef& graph_def, std::vector<std::pair<string, Tensor> > inputs, std::vector<string> excluded_ops, @@ -243,6 +267,8 @@ TEST_F(ConstantFoldingTest, TestSimpleAdd) { TestSimpleAdd(); } 
TEST_F(ConstantFoldingTest, TestOpExclusionAdd) { TestOpExclusionAdd(); } +TEST_F(ConstantFoldingTest, TestShapePropagation) { TestShapePropagation(); } + TEST_F(ConstantFoldingTest, TestReplaceSendRecvs) { TestReplaceSendRecvs(); } TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); } diff --git a/tensorflow/tools/graph_transforms/strip_unused_nodes.cc b/tensorflow/tools/graph_transforms/strip_unused_nodes.cc index 08de934916..ae9d0aa209 100644 --- a/tensorflow/tools/graph_transforms/strip_unused_nodes.cc +++ b/tensorflow/tools/graph_transforms/strip_unused_nodes.cc @@ -74,19 +74,6 @@ Status TypeForPlaceholder(const TransformFuncContext& context, return Status::OK(); } -// Takes a comma-separated string of numbers and parses them into a shape. -bool TensorShapeFromString(const string& shape_string, TensorShape* result) { - if (shape_string.empty()) { - return false; - } - std::vector<int64> dims; - if (!str_util::SplitAndParseAsInts(shape_string, ',', &dims)) { - return false; - } - *result = TensorShape(dims); - return true; -} - Status ShapeForPlaceholder(const TransformFuncContext& context, const string& node_name, TensorShape* result) { // If we don't find anything else, return scalar. @@ -100,10 +87,7 @@ Status ShapeForPlaceholder(const TransformFuncContext& context, "strip_unused_nodes"); } const string& shape_string = context.params.at("shape")[0]; - if (!TensorShapeFromString(shape_string, result)) { - return errors::InvalidArgument("Couldn't understand shape argument '", - shape_string, "'"); - } + TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result)); } // See if there's a particular type specified for this placeholder. 
@@ -121,10 +105,7 @@ Status ShapeForPlaceholder(const TransformFuncContext& context, for (int i = 0; i < name_count; ++i) { if (context.params.at("name")[i] == node_name) { const string& shape_string = context.params.at("shape_for_name")[i]; - if (!TensorShapeFromString(shape_string, result)) { - return errors::InvalidArgument("Couldn't understand shape argument '", - shape_string, "'"); - } + TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result)); } } } diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc index bd1e4c90c0..55f28a9e1d 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.cc +++ b/tensorflow/tools/graph_transforms/transform_utils.cc @@ -586,6 +586,19 @@ Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, return Status::OK(); } +Status TensorShapeFromString(const string& shape_string, TensorShape* result) { + if (shape_string.empty()) { + return errors::InvalidArgument("Specificed shape is empty."); + } + std::vector<int64> dims; + if (!str_util::SplitAndParseAsInts(shape_string, ',', &dims)) { + return errors::InvalidArgument("Could parse as shape: '", shape_string, + "'"); + } + *result = TensorShape(dims); + return Status::OK(); +} + int TransformFuncContext::CountParameters(const string& name) const { if (params.count(name)) { return params.at(name).size(); diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h index c0fb492412..47c8aaed2c 100644 --- a/tensorflow/tools/graph_transforms/transform_utils.h +++ b/tensorflow/tools/graph_transforms/transform_utils.h @@ -133,6 +133,9 @@ Status IsGraphValid(const GraphDef& graph_def); Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs, DataTypeVector* outputs); +// Takes a comma-separated string of numbers and parses them into a shape. 
+Status TensorShapeFromString(const string& shape_string, TensorShape* result); + // This is used to spot particular subgraphs in a larger model. To use it, // create a pattern like: // OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}}); |