aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2018-02-20 11:03:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-20 11:16:18 -0800
commit10386781aebfacd5366bf6af9fc40db35625232e (patch)
treec815d53a4a247155c0a88d233ffbfb2f182dff7e /tensorflow/python/grappler
parent65ac3dfa9a48d209edd50178b7477bbfe0435633 (diff)
Support multiple fetch nodes and add a flag for memory report.
PiperOrigin-RevId: 186329308
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--tensorflow/python/grappler/cost_analyzer_tool.py28
1 files changed, 19 insertions, 9 deletions
diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py
index 86db87d515..0db3c30a27 100644
--- a/tensorflow/python/grappler/cost_analyzer_tool.py
+++ b/tensorflow/python/grappler/cost_analyzer_tool.py
@@ -35,7 +35,8 @@ from tensorflow.python.platform import gfile
from tensorflow.python.training import saver
-def main(_):
+def get_metagraph():
+ """Constructs and returns a MetaGraphDef from the input file."""
if FLAGS.metagraphdef:
with gfile.GFile(FLAGS.metagraphdef) as meta_file:
metagraph = meta_graph_pb2.MetaGraphDef()
@@ -45,7 +46,8 @@ def main(_):
metagraph.ParseFromString(meta_file.read())
if FLAGS.fetch is not None:
fetch_collection = meta_graph_pb2.CollectionDef()
- fetch_collection.node_list.value.append(FLAGS.fetch)
+ for fetch in FLAGS.fetch.split(","):
+ fetch_collection.node_list.value.append(fetch)
metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
else:
with gfile.GFile(FLAGS.graphdef) as graph_file:
@@ -56,11 +58,16 @@ def main(_):
graph_def.ParseFromString(graph_file.read())
importer.import_graph_def(graph_def, name="")
graph = ops.get_default_graph()
- fetch = graph.get_operation_by_name(FLAGS.fetch)
- graph.add_to_collection("train_op", fetch)
+ for fetch in FLAGS.fetch.split(","):
+ fetch_op = graph.get_operation_by_name(fetch)
+ graph.add_to_collection("train_op", fetch_op)
metagraph = saver.export_meta_graph(
graph_def=graph.as_graph_def(), graph=graph)
+ return metagraph
+
+def main(_):
+ metagraph = get_metagraph()
rewriter_config = rewriter_config_pb2.RewriterConfig()
if FLAGS.rewriter_config is not None:
text_format.Merge(FLAGS.rewriter_config, rewriter_config)
@@ -69,8 +76,9 @@ def main(_):
report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report)
print(report)
- report = cost_analyzer.GenerateMemoryReport(metagraph)
- print(report)
+ if FLAGS.memory_report:
+ report = cost_analyzer.GenerateMemoryReport(metagraph)
+ print(report)
if __name__ == "__main__":
@@ -89,9 +97,7 @@ if __name__ == "__main__":
"--fetch",
type=str,
default=None,
- help=
- "The name of the fetch node."
- )
+ help="The names of the fetch node delimited by comma.")
parser.add_argument(
"--rewriter_config",
type=str,
@@ -107,5 +113,9 @@ if __name__ == "__main__":
help="Generate per-node report. By default the report contains stats "
"aggregated on a per op type basis, per_node_report adds results "
"for each individual node to the report.")
+ parser.add_argument(
+ "--memory_report",
+ action="store_true",
+ help="Generate memory usage report.")
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)