aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2018-01-25 17:27:42 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-25 17:31:34 -0800
commit8220c228ab066d23ddf506a58bb08b1694239a34 (patch)
treeef1d3b5e935740158b0942987537496795b20d26
parentfbd3e8a2c01d83a6aa6cca044fe5678d20035451 (diff)
Add an option to input a GraphDef.
PiperOrigin-RevId: 183317862
-rw-r--r--tensorflow/python/grappler/cost_analyzer_tool.py41
1 files changed, 37 insertions, 4 deletions
diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py
index 146bb4311c..61dc4e2afb 100644
--- a/tensorflow/python/grappler/cost_analyzer_tool.py
+++ b/tensorflow/python/grappler/cost_analyzer_tool.py
@@ -23,18 +23,33 @@ import sys
from google.protobuf import text_format
+from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
from tensorflow.python.grappler import cost_analyzer
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
+from tensorflow.python.training import saver
def main(_):
- with gfile.GFile(FLAGS.input) as input_file:
- metagraph = meta_graph_pb2.MetaGraphDef()
- metagraph.ParseFromString(input_file.read())
+ if FLAGS.metagraphdef:
+ with gfile.GFile(FLAGS.metagraphdef) as meta_file:
+ metagraph = meta_graph_pb2.MetaGraphDef()
+ metagraph.ParseFromString(meta_file.read())
+ else:
+ with gfile.GFile(FLAGS.graphdef) as graph_file:
+ graph_def = graph_pb2.GraphDef()
+ graph_def.ParseFromString(graph_file.read())
+ importer.import_graph_def(graph_def, name="")
+ graph = ops.get_default_graph()
+ fetch = graph.get_operation_by_name(FLAGS.fetch)
+ graph.add_to_collection("train_op", fetch)
+ metagraph = saver.export_meta_graph(
+ graph_def=graph.as_graph_def(), graph=graph)
if FLAGS.rewriter_config is not None:
rewriter_config = rewriter_config_pb2.RewriterConfig()
@@ -49,7 +64,25 @@ def main(_):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- "--input", type=str, default=None, help="Input .meta file path.")
+ "--metagraphdef",
+ type=str,
+ default=None,
+ help="Input .meta MetaGraphDef file path.")
+ parser.add_argument(
+ "--graphdef",
+ type=str,
+ default=None,
+ help="Input .pb GraphDef file path.")
+ # Consider making flag fetch work together with flag metagraphdef. As some
+ # MetaGraphDef files don't have collection train_op.
+ parser.add_argument(
+ "--fetch",
+ type=str,
+ default=None,
+ help=
+ "The name of the fetch node. This flag is ignored if flag "
+ "metagraphdef is used."
+ )
parser.add_argument(
"--rewriter_config",
type=str,