diff options
author | 2018-03-21 09:59:18 -0700 | |
---|---|---|
committer | 2018-03-21 10:02:06 -0700 | |
commit | 911225a7eaf2872472484bce5f717d287a0e3224 (patch) | |
tree | debc33f27600ec7d5536267f23199c1de75866ea /tensorflow/python/grappler | |
parent | 335c782f5c504e36e496a33180d8243760a4001c (diff) |
Added an option to run shape analysis assuming the shapes of the feed nodes are
valid.
PiperOrigin-RevId: 189923541
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r-- | tensorflow/python/grappler/model_analyzer.cc | 5 | ||||
-rw-r--r-- | tensorflow/python/grappler/model_analyzer.h | 2 | ||||
-rw-r--r-- | tensorflow/python/grappler/model_analyzer.i | 8 | ||||
-rw-r--r-- | tensorflow/python/grappler/model_analyzer.py | 5 |
4 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc index d23eb811ac..5a76cdd8fb 100644 --- a/tensorflow/python/grappler/model_analyzer.cc +++ b/tensorflow/python/grappler/model_analyzer.cc @@ -26,9 +26,10 @@ namespace grappler { ModelAnalyzer::ModelAnalyzer(const GrapplerItem& item) : item_(item) {} -Status ModelAnalyzer::GenerateReport(bool debug, std::ostream& os) { +Status ModelAnalyzer::GenerateReport(bool debug, bool assume_valid_feeds, + std::ostream& os) { GraphProperties properties(item_); - TF_RETURN_IF_ERROR(properties.InferStatically(false)); + TF_RETURN_IF_ERROR(properties.InferStatically(assume_valid_feeds)); for (const auto& node : item_.MainOpsFanin()) { PrintNodeInfo(node, properties, debug, os); diff --git a/tensorflow/python/grappler/model_analyzer.h b/tensorflow/python/grappler/model_analyzer.h index 5bc551927d..97ffafabe1 100644 --- a/tensorflow/python/grappler/model_analyzer.h +++ b/tensorflow/python/grappler/model_analyzer.h @@ -31,7 +31,7 @@ class GraphProperties; class ModelAnalyzer { public: explicit ModelAnalyzer(const GrapplerItem& item); - Status GenerateReport(bool debug, std::ostream& os); + Status GenerateReport(bool debug, bool assume_valid_feeds, std::ostream& os); private: void PrintNodeInfo(const NodeDef* node, const GraphProperties& properties, diff --git a/tensorflow/python/grappler/model_analyzer.i b/tensorflow/python/grappler/model_analyzer.i index 7c3a692d0e..4955780764 100644 --- a/tensorflow/python/grappler/model_analyzer.i +++ b/tensorflow/python/grappler/model_analyzer.i @@ -40,7 +40,8 @@ limitations under the License. %} %{ -string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, bool debug) { +string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, + bool assume_valid_feeds, bool debug) { tensorflow::grappler::ItemConfig cfg; cfg.apply_optimizations = false; std::unique_ptr<tensorflow::grappler::GrapplerItem> item = @@ -53,10 +54,11 @@ string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, bool debug tensorflow::grappler::ModelAnalyzer analyzer(*item); std::stringstream os; - analyzer.GenerateReport(debug, os); + analyzer.GenerateReport(debug, assume_valid_feeds, os); return os.str(); } %} -string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, bool debug); +string GenerateModelReport(const tensorflow::MetaGraphDef& metagraph, + bool assume_valid_feeds, bool debug); diff --git a/tensorflow/python/grappler/model_analyzer.py b/tensorflow/python/grappler/model_analyzer.py index 535889e1c4..98cdc57850 100644 --- a/tensorflow/python/grappler/model_analyzer.py +++ b/tensorflow/python/grappler/model_analyzer.py @@ -22,11 +22,12 @@ from tensorflow.python import pywrap_tensorflow as tf_wrap from tensorflow.python.framework import errors -def GenerateModelReport(metagraph, debug=False): +def GenerateModelReport(metagraph, assume_valid_feeds=True, debug=False): """Report what's known statically about each node in the provided metagraph. Args: metagraph: A TensorFlow MetaGraphDef. + assume_valid_feeds: If True, assume that the shape of the fed nodes is valid debug: Add some information useful for debugging. Returns: @@ -34,6 +35,6 @@ def GenerateModelReport(metagraph, debug=False): """ with errors.raise_exception_on_not_ok_status(): ret_from_swig = tf_wrap.GenerateModelReport(metagraph.SerializeToString(), - debug) + assume_valid_feeds, debug) return ret_from_swig |