aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2018-03-21 09:59:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-21 10:02:06 -0700
commit911225a7eaf2872472484bce5f717d287a0e3224 (patch)
treedebc33f27600ec7d5536267f23199c1de75866ea /tensorflow/python/grappler
parent335c782f5c504e36e496a33180d8243760a4001c (diff)
Added an option to run shape analysis assuming the shapes of the feed nodes are
valid. PiperOrigin-RevId: 189923541
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--tensorflow/python/grappler/model_analyzer.cc5
-rw-r--r--tensorflow/python/grappler/model_analyzer.h2
-rw-r--r--tensorflow/python/grappler/model_analyzer.i8
-rw-r--r--tensorflow/python/grappler/model_analyzer.py5
4 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc
index d23eb811ac..5a76cdd8fb 100644
--- a/tensorflow/python/grappler/model_analyzer.cc
+++ b/tensorflow/python/grappler/model_analyzer.cc
@@ -26,9 +26,10 @@ namespace grappler {
ModelAnalyzer::ModelAnalyzer(const GrapplerItem& item) : item_(item) {}
-Status ModelAnalyzer::GenerateReport(bool debug, std::ostream& os) {
+Status ModelAnalyzer::GenerateReport(bool debug, bool assume_valid_feeds,
+ std::ostream& os) {
GraphProperties properties(item_);
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ TF_RETURN_IF_ERROR(properties.InferStatically(assume_valid_feeds));
for (const auto& node : item_.MainOpsFanin()) {
PrintNodeInfo(node, properties, debug, os);
diff --git a/tensorflow/python/grappler/model_analyzer.h b/tensorflow/python/grappler/model_analyzer.h
index 5bc551927d..97ffafabe1 100644
--- a/tensorflow/python/grappler/model_analyzer.h
+++ b/tensorflow/python/grappler/model_analyzer.h
@@ -31,7 +31,7 @@ class GraphProperties;
class ModelAnalyzer {
public:
explicit ModelAnalyzer(const GrapplerItem& item);
- Status GenerateReport(bool debug, std::ostream& os);
+ Status GenerateReport(bool debug, bool assume_valid_feeds, std::ostream& os);
private:
void PrintNodeInfo(const NodeDef* node, const GraphProperties& properties,
diff --git a/tensorflow/python/grappler/model_analyzer.i b/tensorflow/python/grappler/model_analyzer.i
index 7c3a692d0e..4955780764 100644
--- a/tensorflow/python/grappler/model_analyzer.i
+++ b/tensorflow/python/grappler/model_analyzer.i
@@ -40,7 +40,8 @@ limitations under the License.
%}
%{
-string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, bool debug) {
+string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph,
+ bool assume_valid_feeds, bool debug) {
tensorflow::grappler::ItemConfig cfg;
cfg.apply_optimizations = false;
std::unique_ptr<tensorflow::grappler::GrapplerItem> item =
@@ -53,10 +54,11 @@ string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, bool debug
tensorflow::grappler::ModelAnalyzer analyzer(*item);
std::stringstream os;
- analyzer.GenerateReport(debug, os);
+ analyzer.GenerateReport(debug, assume_valid_feeds, os);
return os.str();
}
%}
-string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, bool debug);
+string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph,
+ bool assume_valid_feeds, bool debug);
diff --git a/tensorflow/python/grappler/model_analyzer.py b/tensorflow/python/grappler/model_analyzer.py
index 535889e1c4..98cdc57850 100644
--- a/tensorflow/python/grappler/model_analyzer.py
+++ b/tensorflow/python/grappler/model_analyzer.py
@@ -22,11 +22,12 @@ from tensorflow.python import pywrap_tensorflow as tf_wrap
from tensorflow.python.framework import errors
-def GenerateModelReport(metagraph, debug=False):
+def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False):
"""Report what's known statically about each node in the provided metagraph.
Args:
metagraph: A TensorFlow MetaGraphDef.
+ assume_valid_feeds: If True, assume that the shape of the fed nodes is valid
debug: Add some information useful for debugging.
Returns:
@@ -34,6 +35,6 @@ def GenerateModelReport(metagraph, debug=False):
"""
with errors.raise_exception_on_not_ok_status():
ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString(),
- debug)
+ assume_valid_feeds, debug)
return ret_from_swig