Diffstat (limited to 'tensorflow/contrib/tensorrt/python/trt_convert.py')
-rw-r--r--  tensorflow/contrib/tensorrt/python/trt_convert.py  319
1 file changed, 242 insertions(+), 77 deletions(-)
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 4116f2fe30..369e73b5a6 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=unused-import,line-too-long
import six as _six
+# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values
@@ -28,55 +28,179 @@ from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_vers
from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled
+# pylint: enable=unused-import,line-too-long
from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
from tensorflow.python.framework import errors_impl as _impl
+from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import tf_logging
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver
-# pylint: enable=unused-import,line-too-long
+
+if _six.PY2:
+ _to_bytes = lambda s: s
+ _to_string = lambda s: s
+else:
+ _to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape")
+ _to_string = lambda s: s.decode("utf-8")
+
+
+class TrtPrecisionMode(object):
+ FP32 = "FP32"
+ FP16 = "FP16"
+ INT8 = "INT8"
+
+ @staticmethod
+ def supported_precision_modes():
+ return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8]
+
+
+def tensorrt_rewriter_config(max_batch_size=1,
+ max_workspace_size_bytes=2 << 20,
+ precision_mode=TrtPrecisionMode.FP32,
+ minimum_segment_size=3,
+ is_dynamic_op=False,
+ maximum_cached_engines=1,
+ cached_engine_batch_sizes=None):
+ """Returns a RewriterConfig proto for TRT transformation.
+
+ Args:
+ max_batch_size: max size for the input batch.
+ max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+ engine can use at execution time. This corresponds to the 'workspaceSize'
+ parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+ precision_mode: one of TrtPrecisionMode.supported_precision_modes().
+ minimum_segment_size: the minimum number of nodes required for a subgraph to
+ be replaced by TRTEngineOp.
+ is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
+ network and engine at run time.
+ maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
+ If the number of cached engines is already at max but none of them can
+ serve the input, the TRTEngineOp will fall back to running the TF
+ function from which it was created.
+ cached_engine_batch_sizes: a list of batch sizes used to create cached
+ engines, only used when is_dynamic_op is True. The length of the list
+ should not exceed maximum_cached_engines, and the dynamic TRT op will
+ use this list to determine the batch sizes of the cached engines, instead
+ of making the decision on the fly. This is useful when we know the most
+ common batch size(s) the application is going to generate.
+
+ Returns:
+ A RewriterConfig proto which configures Grappler to run the TensorRTOptimizer.
+
+ Raises:
+ TypeError: if cached_engine_batch_sizes is not a list.
+ ValueError: if the provided precision mode is invalid, or if
+ len(cached_engine_batch_sizes) exceeds maximum_cached_engines.
+ """
+ if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes():
+ raise ValueError(("precision mode '{}' is not supported."
+ "It should be one of {}").format(
+ precision_mode,
+ TrtPrecisionMode.supported_precision_modes))
+
+ rewriter_cfg = rewriter_config_pb2.RewriterConfig()
+ rewriter_cfg.optimizers.extend(["constfold", "layout"])
+ optimizer = rewriter_cfg.custom_optimizers.add()
+ optimizer.name = "TensorRTOptimizer"
+ optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
+ optimizer.parameter_map["max_batch_size"].i = max_batch_size
+ optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
+ optimizer.parameter_map[
+ "max_workspace_size_bytes"].i = max_workspace_size_bytes
+ optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
+ optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
+ if cached_engine_batch_sizes:
+ if not isinstance(cached_engine_batch_sizes, list):
+ raise TypeError("cached_engine_batch_sizes should be a list.")
+ if len(cached_engine_batch_sizes) > maximum_cached_engines:
+ raise ValueError("cached_engine_batch_sizes should not contain more than "
+ "maximum_cached_engines items.")
+ optimizer.parameter_map["cached_engine_batches"].list.i.extend(
+ cached_engine_batch_sizes)
+ return rewriter_cfg
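
A minimal usage sketch of the helper added above, assuming the module is imported as tensorflow.contrib.tensorrt.python.trt_convert; the batch sizes and workspace size are illustrative, not part of this change. The returned RewriterConfig can be embedded into a ConfigProto so Grappler runs the TensorRTOptimizer on the next optimization pass:

from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.core.protobuf import config_pb2

# Build a rewriter config for dynamic FP16 engines (example values).
rewriter_cfg = trt_convert.tensorrt_rewriter_config(
    max_batch_size=8,
    max_workspace_size_bytes=1 << 30,
    precision_mode=trt_convert.TrtPrecisionMode.FP16,
    is_dynamic_op=True,
    maximum_cached_engines=2,
    cached_engine_batch_sizes=[1, 8])

# Embed it in a session config so Grappler applies the TensorRTOptimizer.
session_config = config_pb2.ConfigProto()
session_config.graph_options.rewrite_options.CopyFrom(rewriter_cfg)
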
def create_inference_graph(input_graph_def,
outputs,
max_batch_size=1,
max_workspace_size_bytes=2 << 20,
- precision_mode="FP32",
+ precision_mode=TrtPrecisionMode.FP32,
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=None):
+ cached_engine_batch_sizes=None,
+ input_saved_model_dir=None,
+ input_saved_model_tags=None,
+ output_saved_model_dir=None,
+ session_config=None):
"""Python wrapper for the TRT transformation.
Args:
- input_graph_def: GraphDef object containing a model to be transformed.
- outputs: list of tensors or node names for the model outputs.
- max_batch_size: max size for the input batch
- max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
- precision_mode: one of 'FP32', 'FP16' and 'INT8'
+ input_graph_def: a GraphDef object containing a model to be transformed. If
+ set to None, the graph will be read from the SavedModel loaded from
+ input_saved_model_dir.
+ outputs: list of tensors or node names for the model outputs. Only used when
+ input_graph_def is not None.
+ max_batch_size: max size for the input batch.
+ max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+ engine can use at execution time. This corresponds to the 'workspaceSize'
+ parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+ precision_mode: one of TrtPrecisionMode.supported_precision_modes().
minimum_segment_size: the minimum number of nodes required for a subgraph to
be replaced by TRTEngineOp.
is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
network and engine at run time.
maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
- cached_engine_batches: batch sizes used to pre-create cached engines.
+ If the number of cached engines is already at max but none of them can
+ serve the input, the TRTEngineOp will fall back to running the TF
+ function from which it was created.
+ cached_engine_batch_sizes: a list of batch sizes used to create cached
+ engines, only used when is_dynamic_op is True. The length of the list
+ should not exceed maximum_cached_engines, and the dynamic TRT op will
+ use this list to determine the batch sizes of the cached engines, instead
+ of making the decision on the fly. This is useful when we know the most
+ common batch size(s) the application is going to generate.
+ input_saved_model_dir: the directory to load the SavedModel which contains
+ the input graph to transform. Used only when input_graph_def is None.
+ input_saved_model_tags: list of tags to load the SavedModel.
+ output_saved_model_dir: if not None, construct a SavedModel using the
+ returned GraphDef and save it to the specified directory. This option only
+ works when the input graph is loaded from a SavedModel, i.e. when
+ input_saved_model_dir is specified and input_graph_def is None.
+ session_config: the ConfigProto used to create a Session. If not specified,
+ a default ConfigProto will be used.
Returns:
- New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
+ A GraphDef transformed from input_graph_def (or the SavedModel graph def
+ loaded from input_saved_model_dir, if input_graph_def is not present), where
+ all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
+ function is added for each of the subgraphs.
+
+ If is_dynamic_op is True, each TRTEngineOp will contain a serialized
+ subgraph GraphDef, which will be converted to a TRT engine at execution time
+ and the TRT engine will be cached for future usage. A new TRT engine will be
+ created whenever none of the cached engines match the input shapes. If
+ executing the TRT engine fails, or if a new engine would be needed but the
+ number of cached engines has already reached maximum_cached_engines, the op
+ will fall back to calling the corresponding TF function.
+
+ If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
+ engine created from the corresponding subgraph. No more engines will be
+ created on the fly, and the op will fall back to calling the corresponding
+ TF function when it fails to execute the engine.
Raises:
- ValueError: if the provided precision mode is invalid.
- RuntimeError: if the returned status message is malformed.
+ ValueError: if the combination of the parameters is invalid.
+ RuntimeError: if the TensorRT library version is incompatible.
"""
- supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
- if precision_mode.upper() not in supported_precision_modes:
- raise ValueError(("precision mode '{}' is not supported."
- "It should be one of {}").format(
- precision_mode, "{'FP32', 'FP16', 'INT8'}"))
- mode = supported_precision_modes[precision_mode.upper()]
compiled_version = get_linked_tensorrt_version()
loaded_version = get_loaded_tensorrt_version()
version_mismatch = False
@@ -101,61 +225,111 @@ def create_inference_graph(input_graph_def,
tf_logging.info("Running against TensorRT version %s" % ".".join(
[str(x) for x in loaded_version]))
- def py2bytes(inp):
- return inp
+ if session_config is None:
+ session_config = config_pb2.ConfigProto()
+
+ if input_saved_model_tags is None:
+ input_saved_model_tags = [tag_constants.SERVING]
+ saved_model_loader = None
+ grappler_meta_graph_def = None
- def py3bytes(inp):
- return inp.encode("utf-8", errors="surrogateescape")
+ if input_graph_def is None:
+ # Read from SavedModel and freeze the graph if necessary.
+ if input_saved_model_dir is None:
+ raise ValueError("input_graph_def and input_saved_model_dir cannot be "
+ "both None")
+ with ops.Graph().as_default():
+ with session.Session(config=session_config) as sess:
+ saved_model_loader = loader_impl.SavedModelLoader(input_saved_model_dir)
+ input_meta_graph_def = saved_model_loader.load(sess,
+ input_saved_model_tags)
+ output_node_names = set()
- def py2string(inp):
- return inp
+ def _gather_names(tensor_info):
+ """Get the node names from a TensorInfo."""
+ return set(
+ [tensor_info[key].name.split(":")[0] for key in tensor_info])
- def py3string(inp):
- return inp.decode("utf-8")
+ # Get input and outputs from all SignatureDef.
+ for key in input_meta_graph_def.signature_def:
+ signature_def = input_meta_graph_def.signature_def[key]
+ output_node_names.update(_gather_names(signature_def.inputs))
+ output_node_names.update(_gather_names(signature_def.outputs))
- if _six.PY2:
- to_bytes = py2bytes
- to_string = py2string
+ # Freeze the variables in the SavedModel graph and copy the frozen
+ # graph over.
+ frozen_graph_def = graph_util.convert_variables_to_constants(
+ sess, sess.graph.as_graph_def(add_shapes=True),
+ list(output_node_names))
+ grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
+
+ # Copy the collections that are not variables.
+ for key in input_meta_graph_def.collection_def:
+ # TODO(laigd): currently we use the collection key to filter out
+ # collections that depend on variable ops, but this may miss some
+ # other user-defined collections. A better way would be to use
+ # CollectionDef::NodeList for the filtering.
+ if key not in [
+ "variables", "local_variables", "model_variables",
+ "trainable_variables", "train_op", "table_initializer"
+ ]:
+ grappler_meta_graph_def.collection_def[key].CopyFrom(
+ input_meta_graph_def.collection_def[key])
+
+ # Copy other information.
+ grappler_meta_graph_def.meta_info_def.CopyFrom(
+ input_meta_graph_def.meta_info_def)
+ for key in input_meta_graph_def.signature_def:
+ grappler_meta_graph_def.signature_def[key].CopyFrom(
+ input_meta_graph_def.signature_def[key])
+ # TODO(laigd): maybe add back AssetFileDef.
else:
- to_bytes = py3bytes
- to_string = py3string
-
- # Create MetaGraphDef
- graph = ops.Graph()
- with graph.as_default():
- importer.import_graph_def(input_graph_def, name="")
- meta_graph = saver.export_meta_graph(
- graph_def=graph.as_graph_def(), graph=graph)
- if outputs:
- output_collection = meta_graph_pb2.CollectionDef()
- output_list = output_collection.node_list.value
- for i in outputs:
- if isinstance(i, ops.Tensor):
- output_list.append(to_bytes(i.name))
- else:
- output_list.append(to_bytes(i))
- meta_graph.collection_def["train_op"].CopyFrom(output_collection)
+ if output_saved_model_dir is not None:
+ raise ValueError("output_saved_model_dir cannot be set when "
+ "input_graph_def is set")
+ # Create MetaGraphDef from input graph.
+ graph = ops.Graph()
+ with graph.as_default():
+ importer.import_graph_def(input_graph_def, name="")
+ grappler_meta_graph_def = saver.export_meta_graph(
+ graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
+ if outputs:
+ output_collection = meta_graph_pb2.CollectionDef()
+ output_list = output_collection.node_list.value
+ for i in outputs:
+ if isinstance(i, ops.Tensor):
+ output_list.append(_to_bytes(i.name))
+ else:
+ output_list.append(_to_bytes(i))
+ # TODO(laigd): use another key as the outputs are really not train_op.
+ grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
+ output_collection)
# Create RewriterConfig.
- rewriter_cfg = rewriter_config_pb2.RewriterConfig()
- rewriter_cfg.optimizers.extend(["constfold", "layout"])
- optimizer = rewriter_cfg.custom_optimizers.add()
- optimizer.name = "TensorRTOptimizer"
- optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
- optimizer.parameter_map["max_batch_size"].i = max_batch_size
- optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
- optimizer.parameter_map[
- "max_workspace_size_bytes"].i = max_workspace_size_bytes
- optimizer.parameter_map["precision_mode"].s = to_bytes(precision_mode)
- optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
- if cached_engine_batches:
- if not isinstance(cached_engine_batches, list):
- raise TypeError("cached_engine_batches should be a list.")
- optimizer.parameter_map["cached_engine_batches"].list.i.extend(
- cached_engine_batches)
+ rewriter_cfg = tensorrt_rewriter_config(
+ max_batch_size, max_workspace_size_bytes, precision_mode,
+ minimum_segment_size, is_dynamic_op, maximum_cached_engines,
+ cached_engine_batch_sizes)
+
+ # Run Grappler.
+ transformed_graph_def = tf_optimizer.OptimizeGraph(
+ rewriter_cfg, grappler_meta_graph_def, graph_id=b"tf_graph")
- return tf_optimizer.OptimizeGraph(
- rewriter_cfg, meta_graph, graph_id=b"tf_graph")
+ # Optionally write the transformed graphdef as SavedModel.
+ if output_saved_model_dir is not None:
+ saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
+ with ops.Graph().as_default():
+ importer.import_graph_def(transformed_graph_def, name="")
+ with session.Session(config=session_config) as sess:
+ saved_model_builder.add_meta_graph_and_variables(
+ sess,
+ input_saved_model_tags,
+ signature_def_map=grappler_meta_graph_def.signature_def)
+ # Ignore other meta graphs from the input SavedModel.
+ saved_model_builder.save()
+
+ return transformed_graph_def
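
A rough sketch of the two ways to drive the updated entry point; the SavedModel paths, the frozen_graph_def variable, and the output node name are hypothetical placeholders, not part of this change:

from tensorflow.contrib.tensorrt.python import trt_convert

# Path 1: convert a SavedModel and write the TRT-optimized model back out.
trt_graph_def = trt_convert.create_inference_graph(
    input_graph_def=None,
    outputs=None,
    input_saved_model_dir="/tmp/my_saved_model",       # hypothetical path
    output_saved_model_dir="/tmp/my_saved_model_trt",  # hypothetical path
    precision_mode=trt_convert.TrtPrecisionMode.FP16)

# Path 2: convert an already frozen GraphDef; outputs names the result nodes.
trt_graph_def = trt_convert.create_inference_graph(
    input_graph_def=frozen_graph_def,  # hypothetical frozen GraphDef
    outputs=["logits"],                # hypothetical output node name
    max_batch_size=16,
    is_dynamic_op=True)
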
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
@@ -164,22 +338,13 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
Args:
calibration_graph_def: the calibration GraphDef object with calibration data
is_dynamic_op: whether to create dynamic TRT ops (rather than static engines) from the calibration graph
+
Returns:
New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
Raises:
RuntimeError: if the returned status message is malformed.
"""
- def py2string(inp):
- return inp
-
- def py3string(inp):
- return inp.decode("utf-8")
-
- if _six.PY2:
- to_string = py2string
- else:
- to_string = py3string
is_calib_graph = False
for n in calibration_graph_def.node:
if n.op == "TRTEngineOp":
@@ -190,7 +355,7 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
return None
graph_str = calibration_graph_def.SerializeToString()
out = calib_convert(graph_str, is_dynamic_op)
- status = to_string(out[0])
+ status = _to_string(out[0])
output_graph_def_string = out[1]
del graph_str # Save some memory
if len(status) < 2:
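
For context, a rough sketch of the INT8 flow that calib_graph_to_infer_graph completes, assuming calibration runs in the same process; the frozen_graph_def, node names, and calibration_batches below are hypothetical:

from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.python.client import session
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops

# Step 1: build a calibration graph (INT8 mode emits calibration TRTEngineOps).
calib_graph_def = trt_convert.create_inference_graph(
    input_graph_def=frozen_graph_def,  # hypothetical frozen GraphDef
    outputs=["logits"],                # hypothetical output node name
    precision_mode=trt_convert.TrtPrecisionMode.INT8)

# Step 2: run representative batches to collect calibration statistics.
with ops.Graph().as_default():
  importer.import_graph_def(calib_graph_def, name="")
  with session.Session() as sess:
    for batch in calibration_batches:  # hypothetical representative inputs
      sess.run("logits:0", feed_dict={"input:0": batch})

# Step 3: replace the calibration nodes with INT8 TRT engines.
int8_graph_def = trt_convert.calib_graph_to_infer_graph(calib_graph_def)
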