author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-10-03 01:04:18 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-10-03 01:08:11 -0700
commit  0ea4331690c9f00abfbb634a91520042b7b84a20 (patch)
tree    f63f834d30c059055d0b318123e162eb52445688 /tensorflow/tools/graph_transforms
parent  263d025fb6dee974eefb30a51372188fb856d6cc (diff)
Use shape information in constant propagation.
PiperOrigin-RevId: 170818644
Diffstat (limited to 'tensorflow/tools/graph_transforms')
-rw-r--r--  tensorflow/tools/graph_transforms/fold_constants_lib.cc   104
-rw-r--r--  tensorflow/tools/graph_transforms/fold_constants_test.cc   26
-rw-r--r--  tensorflow/tools/graph_transforms/strip_unused_nodes.cc    23
-rw-r--r--  tensorflow/tools/graph_transforms/transform_utils.cc       13
-rw-r--r--  tensorflow/tools/graph_transforms/transform_utils.h         3
5 files changed, 144 insertions, 25 deletions
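
For reference, the per-node shape parameters added below can be exercised through the graph transform tool's command line. A minimal sketch, assuming a bazel-built transform_graph binary; model.pb, input_node, and output_node are illustrative placeholders, not names from this patch:

bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
  --in_graph=model.pb \
  --out_graph=folded.pb \
  --inputs=input_node \
  --outputs=output_node \
  --transforms='fold_constants(name=input_node, shape_for_name="1,224,224,3")'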
diff --git a/tensorflow/tools/graph_transforms/fold_constants_lib.cc b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
index f97e485418..0f5bc2bcdd 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_lib.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_lib.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/core/common_runtime/constant_folding.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
@@ -133,6 +134,61 @@ Status RemoveUnusedNodes(const GraphDef& input_graph_def,
return Status::OK();
}
+// Converts a shape inference handle to a PartialTensorShape.
+Status ShapeHandleToTensorShape(const shape_inference::ShapeHandle& handle,
+ shape_inference::InferenceContext* context,
+ PartialTensorShape* shape) {
+  // The default PartialTensorShape is already unknown, so there's nothing to do.
+ if (!context->RankKnown(handle)) return Status::OK();
+
+  std::vector<int64> dims(context->Rank(handle));
+  for (size_t i = 0; i < dims.size(); ++i) {
+ dims[i] = context->Value(context->Dim(handle, i));
+ }
+ return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
+}
+
+Status ShapeForNode(const TransformFuncContext& context,
+ const string& node_name, TensorShape* result,
+ bool* has_shape_specified) {
+ *has_shape_specified = false;
+
+  // Check to see if we have been given a default shape for all placeholders.
+  if (context.params.count("shape")) {
+ if (context.params.at("shape").size() != 1) {
+ return errors::InvalidArgument(
+ "You must pass no more than one default 'shape' to "
+ "fold_constants");
+ }
+ const string& shape_string = context.params.at("shape")[0];
+ TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result));
+ *has_shape_specified = true;
+ }
+
+  // See if there's a particular shape specified for this placeholder.
+  if (context.params.count("name") || context.params.count("shape_for_name")) {
+    if (!context.params.count("name") ||
+        !context.params.count("shape_for_name") ||
+        (context.params.at("shape_for_name").size() !=
+         context.params.at("name").size())) {
+      return errors::InvalidArgument(
+          "You must pass a 'shape_for_name' arg for every 'name', e.g. "
+          "fold_constants(name=foo, shape_for_name=\"2,2,1\", name=bar, "
+          "shape_for_name=\"1\")");
+ }
+ const int name_count = context.params.at("name").size();
+ for (int i = 0; i < name_count; ++i) {
+ if (context.params.at("name")[i] == node_name) {
+ const string& shape_string = context.params.at("shape_for_name")[i];
+ TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result));
+ *has_shape_specified = true;
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
// Converts any sub-graphs that can be resolved into constant expressions into
// single Const ops.
Status FoldConstants(const GraphDef& input_graph_def,
@@ -142,18 +198,55 @@ Status FoldConstants(const GraphDef& input_graph_def,
// date and cause import errors, so clean them up first.
GraphDef cleaned_graph_def;
RemoveAttributes(input_graph_def, {"_output_shapes"}, &cleaned_graph_def);
+
+ // Set specified shapes.
+ for (NodeDef& node : *cleaned_graph_def.mutable_node()) {
+ TensorShape shape;
+ bool has_shape_specified;
+ TF_RETURN_IF_ERROR(
+ ShapeForNode(context, node.name(), &shape, &has_shape_specified));
+ if (has_shape_specified) {
+ SetNodeAttr("shape", shape, &node);
+ }
+ }
+
Graph input_graph(OpRegistry::Global());
+ ShapeRefiner shape_refiner(input_graph.versions(), input_graph.op_registry());
+ shape_refiner.set_require_shape_inference_fns(true);
+ shape_refiner.set_disable_constant_propagation(false);
ImportGraphDefOptions import_opts;
- TF_RETURN_IF_ERROR(
- ImportGraphDef(import_opts, cleaned_graph_def, &input_graph, nullptr));
+ TF_RETURN_IF_ERROR(ImportGraphDef(import_opts, cleaned_graph_def,
+ &input_graph, &shape_refiner));
DeviceAttributes device_attributes;
subgraph::RewriteGraphMetadata metadata;
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
&input_graph, context.input_names, context.output_names, {},
device_attributes, false /* use_function_convention */, &metadata));
- bool was_mutated;
- // Exclude specified nodes from constant folding.
+
ConstantFoldingOptions cf_opts;
+
+ // Set statically inferred shapes.
+ std::unordered_map<string, std::vector<PartialTensorShape>> shape_map;
+ for (const Node* const node : input_graph.nodes()) {
+ auto ctx = shape_refiner.GetContext(node);
+ if (ctx == nullptr) continue;
+
+ std::vector<PartialTensorShape>* partial_shapes = &shape_map[node->name()];
+ if (ctx->num_outputs() <= 0) continue;
+ partial_shapes->resize(ctx->num_outputs());
+
+ // Check all outputs.
+ for (const Edge* out_edge : node->out_edges()) {
+ if (out_edge->IsControlEdge()) continue;
+
+ const int output_idx = out_edge->src_output();
+ TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(
+ ctx->output(output_idx), ctx, &(*partial_shapes)[output_idx]));
+ }
+ }
+ cf_opts.shape_map = &shape_map;
+
+ // Exclude specified nodes from constant folding.
if (context.params.count("exclude_op") > 0) {
const auto& excluded_nodes = context.params.at("exclude_op");
const std::set<string> excluded_nodes_set(excluded_nodes.begin(),
@@ -163,6 +256,9 @@ Status FoldConstants(const GraphDef& input_graph_def,
excluded_nodes_set.end();
};
}
+
+ // Constant folding.
+ bool was_mutated;
TF_RETURN_IF_ERROR(ConstantFold(cf_opts, nullptr, Env::Default(), nullptr,
&input_graph, &was_mutated));
GraphDef folded_graph_def;
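
The shape_map built above stores PartialTensorShapes, so any dimension the refiner could not pin down arrives as -1 (InferenceContext::kUnknownDim). A minimal standalone sketch of that conversion contract, assuming a TensorFlow build environment (the demo function is hypothetical, not part of this patch):

#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

void PartialShapeDemo() {
  // ShapeHandleToTensorShape writes kUnknownDim (-1) for every dimension
  // the shape refiner could not infer; MakePartialShape preserves it.
  int64 dims[] = {1, -1, 3};
  PartialTensorShape shape;
  TF_CHECK_OK(PartialTensorShape::MakePartialShape(dims, 3, &shape));
  // shape is [1,?,3]: dim_size(1) == -1 and IsFullyDefined() is false, but
  // the known rank still lets constant folding evaluate ops such as Rank.
}

}  // namespace tensorflow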
diff --git a/tensorflow/tools/graph_transforms/fold_constants_test.cc b/tensorflow/tools/graph_transforms/fold_constants_test.cc
index 14e2c01c7c..d4100a652f 100644
--- a/tensorflow/tools/graph_transforms/fold_constants_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_constants_test.cc
@@ -108,6 +108,30 @@ class ConstantFoldingTest : public ::testing::Test {
{"Add"}, {"output_expect_remains"});
}
+ void TestShapePropagation() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ Output placeholder =
+ Placeholder(root.WithOpName("placeholder_expect_remains"), DT_FLOAT);
+ Output a_const =
+ Const(root.WithOpName("a_expect_removed"),
+ Input::Initializer({1, 1, 1}, TensorShape({1, 1, 3})));
+ Output shape = Shape(root.WithOpName("shape_expect_removed"), a_const);
+ Output cast = Cast(root.WithOpName("cast_expect_removed"), shape, DT_FLOAT);
+ Output mul =
+ Mul(root.WithOpName("output_expect_remains"), cast, placeholder);
+
+ GraphDef graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&graph_def));
+
+ Tensor placeholder_tensor(DT_FLOAT, TensorShape({3}));
+ test::FillIota<float>(&placeholder_tensor, 1.0);
+ TestConstantFolding(graph_def,
+ {{"placeholder_expect_remains", placeholder_tensor}},
+ {}, {"output_expect_remains"});
+ }
+
void TestConstantFolding(const GraphDef& graph_def,
std::vector<std::pair<string, Tensor> > inputs,
std::vector<string> excluded_ops,
@@ -243,6 +267,8 @@ TEST_F(ConstantFoldingTest, TestSimpleAdd) { TestSimpleAdd(); }
TEST_F(ConstantFoldingTest, TestOpExclusionAdd) { TestOpExclusionAdd(); }
+TEST_F(ConstantFoldingTest, TestShapePropagation) { TestShapePropagation(); }
+
TEST_F(ConstantFoldingTest, TestReplaceSendRecvs) { TestReplaceSendRecvs(); }
TEST_F(ConstantFoldingTest, TestRemoveUnusedNodes) { TestRemoveUnusedNodes(); }
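
In the new test above, shape inference knows that a_expect_removed has the fully defined shape [1, 1, 3], so Shape and Cast can be evaluated at transform time even though the placeholder's own shape is unknown. Conceptually, the folded graph looks like the sketch below (node names mirror the test; the exact name of the folded constant is an implementation detail the test does not assert):

// Before folding:
//   output_expect_remains = Mul(Cast(Shape(Const[1,1,3])), placeholder)
// After folding, Shape and Cast collapse into a single constant:
//   output_expect_remains = Mul(Const({1.0f, 1.0f, 3.0f}), placeholder)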
diff --git a/tensorflow/tools/graph_transforms/strip_unused_nodes.cc b/tensorflow/tools/graph_transforms/strip_unused_nodes.cc
index 08de934916..ae9d0aa209 100644
--- a/tensorflow/tools/graph_transforms/strip_unused_nodes.cc
+++ b/tensorflow/tools/graph_transforms/strip_unused_nodes.cc
@@ -74,19 +74,6 @@ Status TypeForPlaceholder(const TransformFuncContext& context,
return Status::OK();
}
-// Takes a comma-separated string of numbers and parses them into a shape.
-bool TensorShapeFromString(const string& shape_string, TensorShape* result) {
- if (shape_string.empty()) {
- return false;
- }
- std::vector<int64> dims;
- if (!str_util::SplitAndParseAsInts(shape_string, ',', &dims)) {
- return false;
- }
- *result = TensorShape(dims);
- return true;
-}
-
Status ShapeForPlaceholder(const TransformFuncContext& context,
const string& node_name, TensorShape* result) {
// If we don't find anything else, return scalar.
@@ -100,10 +87,7 @@ Status ShapeForPlaceholder(const TransformFuncContext& context,
"strip_unused_nodes");
}
const string& shape_string = context.params.at("shape")[0];
- if (!TensorShapeFromString(shape_string, result)) {
- return errors::InvalidArgument("Couldn't understand shape argument '",
- shape_string, "'");
- }
+ TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result));
}
// See if there's a particular type specified for this placeholder.
@@ -121,10 +105,7 @@ Status ShapeForPlaceholder(const TransformFuncContext& context,
for (int i = 0; i < name_count; ++i) {
if (context.params.at("name")[i] == node_name) {
const string& shape_string = context.params.at("shape_for_name")[i];
- if (!TensorShapeFromString(shape_string, result)) {
- return errors::InvalidArgument("Couldn't understand shape argument '",
- shape_string, "'");
- }
+ TF_RETURN_IF_ERROR(TensorShapeFromString(shape_string, result));
}
}
}
diff --git a/tensorflow/tools/graph_transforms/transform_utils.cc b/tensorflow/tools/graph_transforms/transform_utils.cc
index bd1e4c90c0..55f28a9e1d 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.cc
+++ b/tensorflow/tools/graph_transforms/transform_utils.cc
@@ -586,6 +586,19 @@ Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
return Status::OK();
}
+Status TensorShapeFromString(const string& shape_string, TensorShape* result) {
+ if (shape_string.empty()) {
+ return errors::InvalidArgument("Specificed shape is empty.");
+ }
+ std::vector<int64> dims;
+ if (!str_util::SplitAndParseAsInts(shape_string, ',', &dims)) {
+ return errors::InvalidArgument("Could parse as shape: '", shape_string,
+ "'");
+ }
+ *result = TensorShape(dims);
+ return Status::OK();
+}
+
int TransformFuncContext::CountParameters(const string& name) const {
if (params.count(name)) {
return params.at(name).size();
diff --git a/tensorflow/tools/graph_transforms/transform_utils.h b/tensorflow/tools/graph_transforms/transform_utils.h
index c0fb492412..47c8aaed2c 100644
--- a/tensorflow/tools/graph_transforms/transform_utils.h
+++ b/tensorflow/tools/graph_transforms/transform_utils.h
@@ -133,6 +133,9 @@ Status IsGraphValid(const GraphDef& graph_def);
Status GetInOutTypes(const NodeDef& node_def, DataTypeVector* inputs,
DataTypeVector* outputs);
+// Takes a comma-separated string of numbers and parses them into a shape.
+Status TensorShapeFromString(const string& shape_string, TensorShape* result);
+
// This is used to spot particular subgraphs in a larger model. To use it,
// create a pattern like:
// OpTypePattern pattern({"Conv2D", {{"ResizeBilinear", {{"MirrorPad"}}}}});
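
Since TensorShapeFromString now surfaces failures as a Status rather than a bool, call sites can rely on TF_RETURN_IF_ERROR, as the two fold/strip call sites above do. A small usage sketch (the wrapper function is hypothetical, not part of this patch):

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {

Status ShapeFromFlagExample() {
  TensorShape shape;
  // "1,224,224,3" parses into a fully defined rank-4 TensorShape.
  TF_RETURN_IF_ERROR(TensorShapeFromString("1,224,224,3", &shape));
  // An empty string or a non-numeric entry such as "1,a,3" now returns
  // errors::InvalidArgument instead of a bare 'false'.
  return Status::OK();
}

}  // namespace graph_transforms
}  // namespace tensorflow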