diff options
author | Yao Zhang <yaozhang@google.com> | 2018-01-25 17:27:42 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-25 17:31:34 -0800 |
commit | 8220c228ab066d23ddf506a58bb08b1694239a34 (patch) | |
tree | ef1d3b5e935740158b0942987537496795b20d26 | |
parent | fbd3e8a2c01d83a6aa6cca044fe5678d20035451 (diff) |
Add an option to input a GraphDef.
PiperOrigin-RevId: 183317862
-rw-r--r-- | tensorflow/python/grappler/cost_analyzer_tool.py | 41 |
1 files changed, 37 insertions, 4 deletions
diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py index 146bb4311c..61dc4e2afb 100644 --- a/tensorflow/python/grappler/cost_analyzer_tool.py +++ b/tensorflow/python/grappler/cost_analyzer_tool.py @@ -23,18 +23,33 @@ import sys from google.protobuf import text_format +from tensorflow.core.framework import graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops from tensorflow.python.grappler import cost_analyzer from tensorflow.python.grappler import tf_optimizer from tensorflow.python.platform import app from tensorflow.python.platform import gfile +from tensorflow.python.training import saver def main(_): - with gfile.GFile(FLAGS.input) as input_file: - metagraph = meta_graph_pb2.MetaGraphDef() - metagraph.ParseFromString(input_file.read()) + if FLAGS.metagraphdef: + with gfile.GFile(FLAGS.metagraphdef) as meta_file: + metagraph = meta_graph_pb2.MetaGraphDef() + metagraph.ParseFromString(meta_file.read()) + else: + with gfile.GFile(FLAGS.graphdef) as graph_file: + graph_def = graph_pb2.GraphDef() + graph_def.ParseFromString(graph_file.read()) + importer.import_graph_def(graph_def, name="") + graph = ops.get_default_graph() + fetch = graph.get_operation_by_name(FLAGS.fetch) + graph.add_to_collection("train_op", fetch) + metagraph = saver.export_meta_graph( + graph_def=graph.as_graph_def(), graph=graph) if FLAGS.rewriter_config is not None: rewriter_config = rewriter_config_pb2.RewriterConfig() @@ -49,7 +64,25 @@ def main(_): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( - "--input", type=str, default=None, help="Input .meta file path.") + "--metagraphdef", + type=str, + default=None, + help="Input .meta MetaGraphDef file path.") + parser.add_argument( + "--graphdef", + type=str, + default=None, + help="Input .pb GraphDef file path.") + # Consider making flag fetch work together with flag metagraphdef. As some + # MetaGraphDef files don't have collection train_op. + parser.add_argument( + "--fetch", + type=str, + default=None, + help= + "The name of the fetch node. This flag is ignored if flag " + "metagraphdef is used." + ) parser.add_argument( "--rewriter_config", type=str, |