aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-02-16 10:06:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-16 10:10:06 -0800
commite31018eab5b30231eeddc9b29a2aca1f1cbb050f (patch)
tree5ee8463d7b072810e500df319231c0cabc3fd695 /tensorflow/python/grappler
parentd2e8a32971bfea647f1c703840ef9a09f728b3f2 (diff)
Made cost_analyzer_tool accept fetch nodes when running with metagraph option. Also made it read metagraph in either binary or text format.
PiperOrigin-RevId: 186010810
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--tensorflow/python/grappler/cost_analyzer_tool.py14
1 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py
index 51b77b471b..86db87d515 100644
--- a/tensorflow/python/grappler/cost_analyzer_tool.py
+++ b/tensorflow/python/grappler/cost_analyzer_tool.py
@@ -39,7 +39,14 @@ def main(_):
if FLAGS.metagraphdef:
with gfile.GFile(FLAGS.metagraphdef) as meta_file:
metagraph = meta_graph_pb2.MetaGraphDef()
- metagraph.ParseFromString(meta_file.read())
+ if FLAGS.metagraphdef.endswith(".pbtxt"):
+ text_format.Merge(meta_file.read(), metagraph)
+ else:
+ metagraph.ParseFromString(meta_file.read())
+ if FLAGS.fetch is not None:
+ fetch_collection = meta_graph_pb2.CollectionDef()
+ fetch_collection.node_list.value.append(FLAGS.fetch)
+ metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
else:
with gfile.GFile(FLAGS.graphdef) as graph_file:
graph_def = graph_pb2.GraphDef()
@@ -78,15 +85,12 @@ if __name__ == "__main__":
type=str,
default=None,
help="Input .pb GraphDef file path.")
- # Consider making flag fetch work together with flag metagraphdef. As some
- # MetaGraphDef files don't have collection train_op.
parser.add_argument(
"--fetch",
type=str,
default=None,
help=
- "The name of the fetch node. This flag is ignored if flag "
- "metagraphdef is used."
+ "The name of the fetch node."
)
parser.add_argument(
"--rewriter_config",