aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/grappler
diff options
context:
space:
mode:
authorGravatar Yao Zhang <yaozhang@google.com>2018-05-03 12:19:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-03 13:37:14 -0700
commite5854637cc3f8099586f18ed144fd6d4f90a6fc7 (patch)
tree0d2f8c3625695f7a729df916ee9af50d666e400a /tensorflow/python/grappler
parenta16ba4fc0d3faec077c689f3f361264978a2d3cb (diff)
Simplify file reading and support SavedModel.
PiperOrigin-RevId: 195291836
Diffstat (limited to 'tensorflow/python/grappler')
-rw-r--r--tensorflow/python/grappler/cost_analyzer_tool.py75
1 files changed, 41 insertions, 34 deletions
diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py
index 0853db2524..e6229e1856 100644
--- a/tensorflow/python/grappler/cost_analyzer_tool.py
+++ b/tensorflow/python/grappler/cost_analyzer_tool.py
@@ -21,11 +21,13 @@ from __future__ import print_function
import argparse
import sys
+from google.protobuf import message
from google.protobuf import text_format
from tensorflow.contrib.fused_conv.ops import gen_fused_conv2d_bias_activation_op # pylint: disable=unused-import
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import cost_analyzer
@@ -37,33 +39,42 @@ from tensorflow.python.training import saver
def get_metagraph():
"""Constructs and returns a MetaGraphDef from the input file."""
- if FLAGS.metagraphdef:
- with gfile.GFile(FLAGS.metagraphdef) as meta_file:
- metagraph = meta_graph_pb2.MetaGraphDef()
- if FLAGS.metagraphdef.endswith(".pbtxt"):
- text_format.Merge(meta_file.read(), metagraph)
- else:
- metagraph.ParseFromString(meta_file.read())
- if FLAGS.fetch is not None:
- fetch_collection = meta_graph_pb2.CollectionDef()
- for fetch in FLAGS.fetch.split(","):
- fetch_collection.node_list.value.append(fetch)
- metagraph.collection_def["train_op"].CopyFrom(fetch_collection)
- else:
- with gfile.GFile(FLAGS.graphdef) as graph_file:
- graph_def = graph_pb2.GraphDef()
- if FLAGS.graphdef.endswith(".pbtxt"):
- text_format.Merge(graph_file.read(), graph_def)
- else:
- graph_def.ParseFromString(graph_file.read())
- importer.import_graph_def(graph_def, name="")
- graph = ops.get_default_graph()
- for fetch in FLAGS.fetch.split(","):
- fetch_op = graph.get_operation_by_name(fetch)
- graph.add_to_collection("train_op", fetch_op)
- metagraph = saver.export_meta_graph(
- graph_def=graph.as_graph_def(), graph=graph)
- return metagraph
+ with gfile.GFile(FLAGS.input) as input_file:
+ input_data = input_file.read()
+ try:
+ saved_model = saved_model_pb2.SavedModel()
+ text_format.Merge(input_data, saved_model)
+ meta_graph = saved_model.meta_graphs[0]
+ except text_format.ParseError:
+ try:
+ saved_model.ParseFromString(input_data)
+ meta_graph = saved_model.meta_graphs[0]
+ except message.DecodeError:
+ try:
+ meta_graph = meta_graph_pb2.MetaGraphDef()
+ text_format.Merge(input_data, meta_graph)
+ except text_format.ParseError:
+ try:
+ meta_graph.ParseFromString(input_data)
+ except message.DecodeError:
+ try:
+ graph_def = graph_pb2.GraphDef()
+ text_format.Merge(input_data, graph_def)
+ except text_format.ParseError:
+ try:
+ graph_def.ParseFromString(input_data)
+ except message.DecodeError:
+ raise ValueError("Invalid input file.")
+ importer.import_graph_def(graph_def, name="")
+ graph = ops.get_default_graph()
+ meta_graph = saver.export_meta_graph(
+ graph_def=graph.as_graph_def(), graph=graph)
+ if FLAGS.fetch is not None:
+ fetch_collection = meta_graph_pb2.CollectionDef()
+ for fetch in FLAGS.fetch.split(","):
+ fetch_collection.node_list.value.append(fetch)
+ meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)
+ return meta_graph
def main(_):
@@ -85,15 +96,11 @@ def main(_):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
- "--metagraphdef",
+ "--input",
type=str,
default=None,
- help="Input .meta MetaGraphDef file path.")
- parser.add_argument(
- "--graphdef",
- type=str,
- default=None,
- help="Input .pb GraphDef file path.")
+ help="Input file path. Accept SavedModel, MetaGraphDef, and GraphDef in "
+ "either binary or text format.")
parser.add_argument(
"--fetch",
type=str,