aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler/cost_analyzer_tool.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/grappler/cost_analyzer_tool.py')
-rw-r--r--tensorflow/python/grappler/cost_analyzer_tool.py41
1 files changed, 37 insertions, 4 deletions
diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py
index 146bb4311c..61dc4e2afb 100644
--- a/tensorflow/python/grappler/cost_analyzer_tool.py
+++ b/tensorflow/python/grappler/cost_analyzer_tool.py
@@ -23,18 +23,33 @@ import sys
from google.protobuf import text_format
+from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
from tensorflow.python.grappler import cost_analyzer
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
+from tensorflow.python.training import saver
def main(_):
- with gfile.GFile(FLAGS.input) as input_file:
- metagraph = meta_graph_pb2.MetaGraphDef()
- metagraph.ParseFromString(input_file.read())
+ if FLAGS.metagraphdef:
+ with gfile.GFile(FLAGS.metagraphdef) as meta_file:
+ metagraph = meta_graph_pb2.MetaGraphDef()
+ metagraph.ParseFromString(meta_file.read())
+ else:
+ with gfile.GFile(FLAGS.graphdef) as graph_file:
+ graph_def = graph_pb2.GraphDef()
+ graph_def.ParseFromString(graph_file.read())
+ importer.import_graph_def(graph_def, name="")
+ graph = ops.get_default_graph()
+ fetch = graph.get_operation_by_name(FLAGS.fetch)
+ graph.add_to_collection("train_op", fetch)
+ metagraph = saver.export_meta_graph(
+ graph_def=graph.as_graph_def(), graph=graph)
if FLAGS.rewriter_config is not None:
rewriter_config = rewriter_config_pb2.RewriterConfig()
@@ -49,7 +64,25 @@ def main(_):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- "--input", type=str, default=None, help="Input .meta file path.")
+ "--metagraphdef",
+ type=str,
+ default=None,
+ help="Input .meta MetaGraphDef file path.")
+ parser.add_argument(
+ "--graphdef",
+ type=str,
+ default=None,
+ help="Input .pb GraphDef file path.")
+ # Consider making flag fetch work together with flag metagraphdef. As some
+ # MetaGraphDef files don't have collection train_op.
+ parser.add_argument(
+ "--fetch",
+ type=str,
+ default=None,
+ help=
+ "The name of the fetch node. This flag is ignored if flag "
+ "metagraphdef is used."
+ )
parser.add_argument(
"--rewriter_config",
type=str,