diff options
author | Yao Zhang <yaozhang@google.com> | 2018-02-20 11:03:08 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-20 11:16:18 -0800 |
commit | 10386781aebfacd5366bf6af9fc40db35625232e (patch) | |
tree | c815d53a4a247155c0a88d233ffbfb2f182dff7e /tensorflow/python/grappler | |
parent | 65ac3dfa9a48d209edd50178b7477bbfe0435633 (diff) |
Support multiple fetch nodes and add a flag for memory report.
PiperOrigin-RevId: 186329308
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r-- | tensorflow/python/grappler/cost_analyzer_tool.py | 28 |
1 files changed, 19 insertions, 9 deletions
diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py index 86db87d515..0db3c30a27 100644 --- a/tensorflow/python/grappler/cost_analyzer_tool.py +++ b/tensorflow/python/grappler/cost_analyzer_tool.py @@ -35,7 +35,8 @@ from tensorflow.python.platform import gfile from tensorflow.python.training import saver -def main(_): +def get_metagraph(): + """Constructs and returns a MetaGraphDef from the input file.""" if FLAGS.metagraphdef: with gfile.GFile(FLAGS.metagraphdef) as meta_file: metagraph = meta_graph_pb2.MetaGraphDef() @@ -45,7 +46,8 @@ def main(_): metagraph.ParseFromString(meta_file.read()) if FLAGS.fetch is not None: fetch_collection = meta_graph_pb2.CollectionDef() - fetch_collection.node_list.value.append(FLAGS.fetch) + for fetch in FLAGS.fetch.split(","): + fetch_collection.node_list.value.append(fetch) metagraph.collection_def["train_op"].CopyFrom(fetch_collection) else: with gfile.GFile(FLAGS.graphdef) as graph_file: @@ -56,11 +58,16 @@ def main(_): graph_def.ParseFromString(graph_file.read()) importer.import_graph_def(graph_def, name="") graph = ops.get_default_graph() - fetch = graph.get_operation_by_name(FLAGS.fetch) - graph.add_to_collection("train_op", fetch) + for fetch in FLAGS.fetch.split(","): + fetch_op = graph.get_operation_by_name(fetch) + graph.add_to_collection("train_op", fetch_op) metagraph = saver.export_meta_graph( graph_def=graph.as_graph_def(), graph=graph) + return metagraph + +def main(_): + metagraph = get_metagraph() rewriter_config = rewriter_config_pb2.RewriterConfig() if FLAGS.rewriter_config is not None: text_format.Merge(FLAGS.rewriter_config, rewriter_config) @@ -69,8 +76,9 @@ def main(_): report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report) print(report) - report = cost_analyzer.GenerateMemoryReport(metagraph) - print(report) + if FLAGS.memory_report: + report = cost_analyzer.GenerateMemoryReport(metagraph) + print(report) if __name__ == "__main__": @@ -89,9 +97,7 @@ if __name__ == "__main__": "--fetch", type=str, default=None, - help= - "The name of the fetch node." - ) + help="The names of the fetch node delimited by comma.") parser.add_argument( "--rewriter_config", type=str, @@ -107,5 +113,9 @@ if __name__ == "__main__": help="Generate per-node report. By default the report contains stats " "aggregated on a per op type basis, per_node_report adds results " "for each individual node to the report.") + parser.add_argument( + "--memory_report", + action="store_true", + help="Generate memory usage report.") FLAGS, unparsed = parser.parse_known_args() app.run(main=main, argv=[sys.argv[0]] + unparsed) |