diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-02-16 10:06:14 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-16 10:10:06 -0800 |
commit | e31018eab5b30231eeddc9b29a2aca1f1cbb050f (patch) | |
tree | 5ee8463d7b072810e500df319231c0cabc3fd695 /tensorflow/python/grappler | |
parent | d2e8a32971bfea647f1c703840ef9a09f728b3f2 (diff) |
Made cost_analyzer_tool accept fetch nodes when running with metagraph option. Also made it read metagraph in either binary or text format.
PiperOrigin-RevId: 186010810
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r-- | tensorflow/python/grappler/cost_analyzer_tool.py | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py index 51b77b471b..86db87d515 100644 --- a/tensorflow/python/grappler/cost_analyzer_tool.py +++ b/tensorflow/python/grappler/cost_analyzer_tool.py @@ -39,7 +39,14 @@ def main(_): if FLAGS.metagraphdef: with gfile.GFile(FLAGS.metagraphdef) as meta_file: metagraph = meta_graph_pb2.MetaGraphDef() - metagraph.ParseFromString(meta_file.read()) + if FLAGS.metagraphdef.endswith(".pbtxt"): + text_format.Merge(meta_file.read(), metagraph) + else: + metagraph.ParseFromString(meta_file.read()) + if FLAGS.fetch is not None: + fetch_collection = meta_graph_pb2.CollectionDef() + fetch_collection.node_list.value.append(FLAGS.fetch) + metagraph.collection_def["train_op"].CopyFrom(fetch_collection) else: with gfile.GFile(FLAGS.graphdef) as graph_file: graph_def = graph_pb2.GraphDef() @@ -78,15 +85,12 @@ if __name__ == "__main__": type=str, default=None, help="Input .pb GraphDef file path.") - # Consider making flag fetch work together with flag metagraphdef. As some - # MetaGraphDef files don't have collection train_op. parser.add_argument( "--fetch", type=str, default=None, help= - "The name of the fetch node. This flag is ignored if flag " - "metagraphdef is used." + "The name of the fetch node." ) parser.add_argument( "--rewriter_config", |