author     Guangda Lai <laigd@google.com>                    2018-09-12 19:05:02 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-09-12 19:08:27 -0700
commit     28ede9ed7caee0ce2731d95cc0eb9aff7f360105 (patch)
tree       b71838ce0eac5bd833bd4b2555cec9428ca3c87b /tensorflow/contrib/tensorrt
parent     30b711b07570b12c8880532aede428503c35e310 (diff)
Add SavedModel support to TensorRT's create_inference_graph() API.
PiperOrigin-RevId: 212743550
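
With this change, create_inference_graph() can read its input graph directly from a SavedModel and optionally re-export the converted graph as a SavedModel, instead of requiring a pre-frozen GraphDef. A minimal usage sketch of the new code path, assuming TF-TRT is compiled in; the directory paths are hypothetical placeholders, not part of this commit:

    # Sketch only: convert a SavedModel with TF-TRT and re-export it.
    # "/tmp/my_saved_model" and "/tmp/my_trt_saved_model" are placeholder paths.
    from tensorflow.contrib import tensorrt as trt

    trt_graph_def = trt.create_inference_graph(
        input_graph_def=None,   # None: load the graph from the SavedModel below
        outputs=None,           # outputs come from the SavedModel SignatureDefs
        max_batch_size=8,
        precision_mode="FP16",
        input_saved_model_dir="/tmp/my_saved_model",
        output_saved_model_dir="/tmp/my_trt_saved_model")
    # The converted GraphDef is returned, and a converted SavedModel is also
    # written to output_saved_model_dir.
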
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--  tensorflow/contrib/tensorrt/BUILD                                 |  31
-rw-r--r--  tensorflow/contrib/tensorrt/python/trt_convert.py                 | 319
-rw-r--r--  tensorflow/contrib/tensorrt/python/trt_convert_test.py            | 293
-rw-r--r--  tensorflow/contrib/tensorrt/test/test_tftrt.py                    |   6
-rw-r--r--  tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py  |  28
5 files changed, 577 insertions, 100 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 122a67a407..9e8979bce4 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -19,6 +19,7 @@ load(
     "tf_gen_op_libs",
     "tf_gen_op_wrapper_py",
 )
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
 load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
 load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
 load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -181,7 +182,12 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":wrap_conversion",
+        "//tensorflow/python:graph_util",
+        "//tensorflow/python:session",
         "//tensorflow/python:tf_optimizer",
+        "//tensorflow/python/saved_model:builder",
+        "//tensorflow/python/saved_model:loader",
+        "//tensorflow/python/saved_model:tag_constants",
     ],
 )
 
@@ -410,6 +416,31 @@ py_library(
     ],
 )
 
+cuda_py_test(
+    name = "trt_convert_test",
+    srcs = ["python/trt_convert_test.py"],
+    additional_deps = [
+        ":trt_convert_py",
+        ":trt_ops_py",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:graph_util",
+        "//tensorflow/python/saved_model:builder",
+        "//tensorflow/python/saved_model:loader",
+        "//tensorflow/python/saved_model:signature_constants",
+        "//tensorflow/python/saved_model:signature_def_utils",
+        "//tensorflow/python/saved_model:tag_constants",
+        "//tensorflow/python/saved_model:utils",
+        "//tensorflow/python/tools:freeze_graph_lib",
+        "//tensorflow/python/tools:saved_model_utils",
+    ],
+    tags = [
+        "no_cuda_on_cpu_tap",
+        "no_windows",
+        "nomac",
+    ],
+)
+
 cuda_py_tests(
     name = "tf_trt_integration_test",
     srcs = [
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 4116f2fe30..369e73b5a6 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-# pylint: disable=unused-import,line-too-long
 import six as _six
+# pylint: disable=unused-import,line-too-long
 from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value
 from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
 from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values
@@ -28,55 +28,179 @@ from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_vers
 from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
 from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value
 from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled
+# pylint: enable=unused-import,line-too-long
 from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import config_pb2
 from tensorflow.core.protobuf import meta_graph_pb2
 from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
 from tensorflow.python.framework import errors_impl as _impl
+from tensorflow.python.framework import graph_util
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
 from tensorflow.python.grappler import tf_optimizer
 from tensorflow.python.platform import tf_logging
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.training import saver
-# pylint: enable=unused-import,line-too-long
+
+
+if _six.PY2:
+  _to_bytes = lambda s: s
+  _to_string = lambda s: s
+else:
+  _to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape")
+  _to_string = lambda s: s.decode("utf-8")
+
+
+class TrtPrecisionMode(object):
+  FP32 = "FP32"
+  FP16 = "FP16"
+  INT8 = "INT8"
+
+  @staticmethod
+  def supported_precision_modes():
+    return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8]
+
+
+def tensorrt_rewriter_config(max_batch_size=1,
+                             max_workspace_size_bytes=2 << 20,
+                             precision_mode=TrtPrecisionMode.FP32,
+                             minimum_segment_size=3,
+                             is_dynamic_op=False,
+                             maximum_cached_engines=1,
+                             cached_engine_batch_sizes=None):
+  """Returns a RewriterConfig proto for TRT transformation.
+
+  Args:
+    max_batch_size: max size for the input batch
+    max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+      engine can use at execution time. This corresponds to the 'workspaceSize'
+      parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+    precision_mode: one of TrtPrecisionMode.supported_precision_modes().
+    minimum_segment_size: the minimum number of nodes required for a subgraph to
+      be replaced by TRTEngineOp.
+    is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
+      network and engine at run time.
+    maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
+      If the number of cached engines is already at max but none of them can
+      serve the input, the TRTEngineOp will fall back to run the TF function
+      based on which the TRTEngineOp is created.
+    cached_engine_batch_sizes: a list of batch sizes used to create cached
+      engines, only used when is_dynamic_op is True. The length of the list
+      should be smaller than maximum_cached_engines, and the dynamic TRT op will
+      use this list to determine the batch sizes of the cached engines, instead
+      of making the decision on the fly. This is useful when we know the most
+      common batch size(s) the application is going to generate.
+
+  Returns:
+    A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
+
+  Raises:
+    TypeError: if the provided precision mode is invalid.
+    ValueError: if len(cached_engine_batch_sizes) exceed maximum_cached_engines.
+  """
+  if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes():
+    raise ValueError(("precision mode '{}' is not supported."
+ "It should be one of {}").format( + precision_mode, + TrtPrecisionMode.supported_precision_modes)) + + rewriter_cfg = rewriter_config_pb2.RewriterConfig() + rewriter_cfg.optimizers.extend(["constfold", "layout"]) + optimizer = rewriter_cfg.custom_optimizers.add() + optimizer.name = "TensorRTOptimizer" + optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size + optimizer.parameter_map["max_batch_size"].i = max_batch_size + optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op + optimizer.parameter_map[ + "max_workspace_size_bytes"].i = max_workspace_size_bytes + optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode) + optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines + if cached_engine_batch_sizes: + if not isinstance(cached_engine_batch_sizes, list): + raise TypeError("cached_engine_batch_sizes should be a list.") + if len(cached_engine_batch_sizes) > maximum_cached_engines: + raise ValueError("cached_engine_batch_sizes should not contain more than " + "maximum_cached_engines items.") + optimizer.parameter_map["cached_engine_batches"].list.i.extend( + cached_engine_batch_sizes) + return rewriter_cfg def create_inference_graph(input_graph_def, outputs, max_batch_size=1, max_workspace_size_bytes=2 << 20, - precision_mode="FP32", + precision_mode=TrtPrecisionMode.FP32, minimum_segment_size=3, is_dynamic_op=False, maximum_cached_engines=1, - cached_engine_batches=None): + cached_engine_batch_sizes=None, + input_saved_model_dir=None, + input_saved_model_tags=None, + output_saved_model_dir=None, + session_config=None): """Python wrapper for the TRT transformation. Args: - input_graph_def: GraphDef object containing a model to be transformed. - outputs: list of tensors or node names for the model outputs. - max_batch_size: max size for the input batch - max_workspace_size_bytes: parameter to control memory allocation (in Bytes) - precision_mode: one of 'FP32', 'FP16' and 'INT8' + input_graph_def: a GraphDef object containing a model to be transformed. If + set to None, the graph will be read from the SavedModel loaded from + input_saved_model_dir. + outputs: list of tensors or node names for the model outputs. Only used when + input_graph_def is not None. + max_batch_size: max size for the input batch. + max_workspace_size_bytes: the maximum GPU temporary memory which the TRT + engine can use at execution time. This corresponds to the 'workspaceSize' + parameter of nvinfer1::IBuilder::setMaxWorkspaceSize(). + precision_mode: one of TrtPrecisionMode.supported_precision_modes(). minimum_segment_size: the minimum number of nodes required for a subgraph to be replaced by TRTEngineOp. is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT network and engine at run time. maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops. - cached_engine_batches: batch sizes used to pre-create cached engines. + If the number of cached engines is already at max but none of them can + serve the input, the TRTEngineOp will fall back to run the TF function + based on which the TRTEngineOp is created. + cached_engine_batch_sizes: a list of batch sizes used to create cached + engines, only used when is_dynamic_op is True. The length of the list + should be smaller than maximum_cached_engines, and the dynamic TRT op will + use this list to determine the batch sizes of the cached engines, instead + of making the decision on the fly. 
This is useful when we know the most + common batch size(s) the application is going to generate. + input_saved_model_dir: the directory to load the SavedModel which contains + the input graph to transforms. Used only when input_graph_def is None. + input_saved_model_tags: list of tags to load the SavedModel. + output_saved_model_dir: if not None, construct a SavedModel using the + returned GraphDef and save it to the specified directory. This option only + works when the input graph is loaded from a SavedModel, i.e. when + input_saved_model_dir is specified and input_graph_def is None. + session_config: the ConfigProto used to create a Session. If not specified, + a default ConfigProto will be used. Returns: - New GraphDef with TRTEngineOps placed in graph replacing subgraphs. + A GraphDef transformed from input_graph_def (or the SavedModel graph def + loaded from input_saved_model_dir, if input_graph_def is not present), where + all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF + function is added for each of the subgraphs. + + If is_dynamic_op is True, each TRTEngineOp will contain a serialized + subgraph GraphDef, which will be converted to a TRT engine at execution time + and the TRT engine will be cached for future usage. A new TRT engine will be + created each time when none of the cached engines match the input shapes. If + it fails to execute the TRT engine or the number of cached engines reaches + maximum_cached_engines, the op will fall back to call the corresponding TF + function. + + If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT + engine created from the corresponding subgraph. No more engines will be + created on the fly, and the op will fall back to call the corresponding TF + function when it fails to execute the engine. Raises: - ValueError: if the provided precision mode is invalid. - RuntimeError: if the returned status message is malformed. + ValueError: if the combination of the parameters is invalid. + RuntimeError: if the TensorRT library version is incompatible. """ - supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2} - if precision_mode.upper() not in supported_precision_modes: - raise ValueError(("precision mode '{}' is not supported." - "It should be one of {}").format( - precision_mode, "{'FP32', 'FP16', 'INT8'}")) - mode = supported_precision_modes[precision_mode.upper()] compiled_version = get_linked_tensorrt_version() loaded_version = get_loaded_tensorrt_version() version_mismatch = False @@ -101,61 +225,111 @@ def create_inference_graph(input_graph_def, tf_logging.info("Running against TensorRT version %s" % ".".join( [str(x) for x in loaded_version])) - def py2bytes(inp): - return inp + if session_config is None: + session_config = config_pb2.ConfigProto() + + if input_saved_model_tags is None: + input_saved_model_tags = [tag_constants.SERVING] + saved_model_loader = None + grappler_meta_graph_def = None - def py3bytes(inp): - return inp.encode("utf-8", errors="surrogateescape") + if input_graph_def is None: + # Read from SavedModel and freeze the graph if necessary. 
+    if input_saved_model_dir is None:
+      raise ValueError("input_graph_def and input_saved_model_dir cannot be "
+                       "both None")
+    with ops.Graph().as_default():
+      with session.Session(config=session_config) as sess:
+        saved_model_loader = loader_impl.SavedModelLoader(input_saved_model_dir)
+        input_meta_graph_def = saved_model_loader.load(sess,
+                                                       input_saved_model_tags)
+        output_node_names = set()
 
-  def py2string(inp):
-    return inp
+        def _gather_names(tensor_info):
+          """Get the node names from a TensorInfo."""
+          return set(
+              [tensor_info[key].name.split(":")[0] for key in tensor_info])
 
-  def py3string(inp):
-    return inp.decode("utf-8")
+        # Get input and outputs from all SignatureDef.
+        for key in input_meta_graph_def.signature_def:
+          signature_def = input_meta_graph_def.signature_def[key]
+          output_node_names.update(_gather_names(signature_def.inputs))
+          output_node_names.update(_gather_names(signature_def.outputs))
 
-  if _six.PY2:
-    to_bytes = py2bytes
-    to_string = py2string
+        # Freeze the variables in the SavedModel graph and copy the frozen
+        # graph over.
+        frozen_graph_def = graph_util.convert_variables_to_constants(
+            sess, sess.graph.as_graph_def(add_shapes=True),
+            list(output_node_names))
+        grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+        grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
+
+        # Copy the collections that are not variables.
+        for key in input_meta_graph_def.collection_def:
+          # TODO(laigd): currently we use the collection key to filter out
+          # collections that depend on variable ops, but this may miss some
+          # other user-defined collections. A better way would be to use
+          # CollectionDef::NodeList for the filtering.
+          if key not in [
+              "variables", "local_variables", "model_variables",
+              "trainable_variables", "train_op", "table_initializer"
+          ]:
+            grappler_meta_graph_def.collection_def[key].CopyFrom(
+                input_meta_graph_def.collection_def[key])
+
+        # Copy other information.
+        grappler_meta_graph_def.meta_info_def.CopyFrom(
+            input_meta_graph_def.meta_info_def)
+        for key in input_meta_graph_def.signature_def:
+          grappler_meta_graph_def.signature_def[key].CopyFrom(
+              input_meta_graph_def.signature_def[key])
+        # TODO(laigd): maybe add back AssetFileDef.
   else:
-    to_bytes = py3bytes
-    to_string = py3string
-
-  # Create MetaGraphDef
-  graph = ops.Graph()
-  with graph.as_default():
-    importer.import_graph_def(input_graph_def, name="")
-  meta_graph = saver.export_meta_graph(
-      graph_def=graph.as_graph_def(), graph=graph)
-  if outputs:
-    output_collection = meta_graph_pb2.CollectionDef()
-    output_list = output_collection.node_list.value
-    for i in outputs:
-      if isinstance(i, ops.Tensor):
-        output_list.append(to_bytes(i.name))
-      else:
-        output_list.append(to_bytes(i))
-    meta_graph.collection_def["train_op"].CopyFrom(output_collection)
+    if output_saved_model_dir is not None:
+      raise ValueError("output_saved_model_dir cannot be set when "
+                       "input_graph_def is set")
+    # Create MetaGraphDef from input graph.
+    graph = ops.Graph()
+    with graph.as_default():
+      importer.import_graph_def(input_graph_def, name="")
+    grappler_meta_graph_def = saver.export_meta_graph(
+        graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
+    if outputs:
+      output_collection = meta_graph_pb2.CollectionDef()
+      output_list = output_collection.node_list.value
+      for i in outputs:
+        if isinstance(i, ops.Tensor):
+          output_list.append(_to_bytes(i.name))
+        else:
+          output_list.append(_to_bytes(i))
+      # TODO(laigd): use another key as the outputs are really not train_op.
+      grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
+          output_collection)
 
   # Create RewriterConfig.
-  rewriter_cfg = rewriter_config_pb2.RewriterConfig()
-  rewriter_cfg.optimizers.extend(["constfold", "layout"])
-  optimizer = rewriter_cfg.custom_optimizers.add()
-  optimizer.name = "TensorRTOptimizer"
-  optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
-  optimizer.parameter_map["max_batch_size"].i = max_batch_size
-  optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
-  optimizer.parameter_map[
-      "max_workspace_size_bytes"].i = max_workspace_size_bytes
-  optimizer.parameter_map["precision_mode"].s = to_bytes(precision_mode)
-  optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
-  if cached_engine_batches:
-    if not isinstance(cached_engine_batches, list):
-      raise TypeError("cached_engine_batches should be a list.")
-    optimizer.parameter_map["cached_engine_batches"].list.i.extend(
-        cached_engine_batches)
+  rewriter_cfg = tensorrt_rewriter_config(
+      max_batch_size, max_workspace_size_bytes, precision_mode,
+      minimum_segment_size, is_dynamic_op, maximum_cached_engines,
+      cached_engine_batch_sizes)
+
+  # Run Grappler.
+  transformed_graph_def = tf_optimizer.OptimizeGraph(
+      rewriter_cfg, grappler_meta_graph_def, graph_id=b"tf_graph")
 
-  return tf_optimizer.OptimizeGraph(
-      rewriter_cfg, meta_graph, graph_id=b"tf_graph")
+  # Optionally write the transformed graphdef as SavedModel.
+  if output_saved_model_dir is not None:
+    saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
+    with ops.Graph().as_default():
+      importer.import_graph_def(transformed_graph_def, name="")
+      with session.Session(config=session_config) as sess:
+        saved_model_builder.add_meta_graph_and_variables(
+            sess,
+            input_saved_model_tags,
+            signature_def_map=grappler_meta_graph_def.signature_def)
+    # Ignore other meta graphs from the input SavedModel.
+    saved_model_builder.save()
+
+  return transformed_graph_def
 
 
 def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
@@ -164,22 +338,13 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
   Args:
     calibration_graph_def: the calibration GraphDef object with calibration data
     is_dynamic_op: whether to create dynamic static engines from calibration
+
   Returns:
     New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
 
   Raises:
     RuntimeError: if the returned status message is malformed.
   """
-
-  def py2string(inp):
-    return inp
-
-  def py3string(inp):
-    return inp.decode("utf-8")
-
-  if _six.PY2:
-    to_string = py2string
-  else:
-    to_string = py3string
   is_calib_graph = False
   for n in calibration_graph_def.node:
     if n.op == "TRTEngineOp":
@@ -190,7 +355,7 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
       return None
   graph_str = calibration_graph_def.SerializeToString()
   out = calib_convert(graph_str, is_dynamic_op)
-  status = to_string(out[0])
+  status = _to_string(out[0])
   output_graph_def_string = out[1]
   del graph_str  # Save some memory
   if len(status) < 2:
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
new file mode 100644
index 0000000000..118a6680fd
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
@@ -0,0 +1,293 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.tensorrt.python import trt_convert
+# pylint: disable=unused-import
+from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+# pylint: enable=unused-import
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import graph_util
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.saved_model import utils
+from tensorflow.python.tools import saved_model_utils
+
+
+class TrtConvertTest(test_util.TensorFlowTestCase):
+  """Class to test Tensorflow-TensorRT integration python API."""
+
+  def testTensorrtRewriterConfig(self):
+    """Test case for trt_convert.tensorrt_rewriter_config()."""
+    rewriter_cfg = trt_convert.tensorrt_rewriter_config(
+        max_batch_size=128,
+        max_workspace_size_bytes=1234,
+        precision_mode="INT8",
+        minimum_segment_size=10,
+        is_dynamic_op=True,
+        maximum_cached_engines=2,
+        cached_engine_batch_sizes=[1, 128])
+    trt_optimizer = None
+    for optimizer in rewriter_cfg.custom_optimizers:
+      if optimizer.name == "TensorRTOptimizer":
+        self.assertTrue(trt_optimizer is None)
+        trt_optimizer = optimizer
+    self.assertTrue(trt_optimizer is not None)
+    for key in [
+        "minimum_segment_size", "max_batch_size", "is_dynamic_op",
+        "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines",
+        "cached_engine_batches"
+    ]:
+      self.assertTrue(key in trt_optimizer.parameter_map)
+    self.assertEqual(10, trt_optimizer.parameter_map["minimum_segment_size"].i)
+    self.assertEqual(128, trt_optimizer.parameter_map["max_batch_size"].i)
+    self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b)
+    self.assertEqual(1234,
+                     trt_optimizer.parameter_map["max_workspace_size_bytes"].i)
+    self.assertEqual(
+        trt_convert._to_bytes("INT8"),
+        trt_optimizer.parameter_map["precision_mode"].s)
+    self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i)
+    self.assertEqual(
+        [1, 128],
+        trt_optimizer.parameter_map["cached_engine_batches"].list.i)
+
+  def _GetConfigProto(self):
+    """Get ConfigProto for session creation."""
+    config = config_pb2.ConfigProto(
+        gpu_options=config_pb2.GPUOptions(allow_growth=True))
+    return config
+
+  def _GetGraph(self):
+    """Get the graph for testing."""
+    g = ops.Graph()
+    with g.as_default():
+      with g.device("/GPU:0"):
+        inp = array_ops.placeholder(
+            dtype=dtypes.float32, shape=[None, 1, 1], name="input")
+        var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
+        add = inp + var.value()
+        mul = inp * add
+        add = mul + add
+        out = array_ops.identity(add, name="output")
+    return g, var, inp, out
+
+  def _GetGraphDef(self):
+    """Get the graph def for testing."""
+    g, var, _, _ = self._GetGraph()
+    with self.test_session(graph=g, config=self._GetConfigProto()) as sess:
+      sess.run(var.initializer)
+      graph_def = graph_util.convert_variables_to_constants(
+          sess, g.as_graph_def(add_shapes=True), ["output"])
+    node_name_to_op = {node.name: node.op for node in graph_def.node}
+    self.assertEqual({
+        "v1": "Const",
+        "v1/read": "Identity",
+        "input": "Placeholder",
+        "add": "Add",
+        "mul": "Mul",
+        "add_1": "Add",
+        "output": "Identity"
+    }, node_name_to_op)
+    return graph_def
+
+  def _WriteInputSavedModel(self, input_saved_model_dir):
+    """Write the saved model as an input for testing."""
+    g, var, inp, out = self._GetGraph()
+    signature_def = signature_def_utils.build_signature_def(
+        inputs={"myinput": utils.build_tensor_info(inp)},
+        outputs={"myoutput": utils.build_tensor_info(out)},
+        method_name=signature_constants.PREDICT_METHOD_NAME)
+    saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
+    with self.test_session(graph=g, config=self._GetConfigProto()) as sess:
+      sess.run(var.initializer)
+      saved_model_builder.add_meta_graph_and_variables(
+          sess, [tag_constants.SERVING],
+          signature_def_map={"mypredict": signature_def})
+    saved_model_builder.save()
+
+  def _TestCreateInferenceGraph(self,
+                                input_saved_model_dir=None,
+                                output_saved_model_dir=None):
+    """General method to test trt_convert.create_inference_graph()."""
+    input_graph_def = None if input_saved_model_dir else self._GetGraphDef()
+    output_graph_def = trt_convert.create_inference_graph(
+        input_graph_def, ["output"],
+        input_saved_model_dir=input_saved_model_dir,
+        output_saved_model_dir=output_saved_model_dir,
+        session_config=self._GetConfigProto())
+    graph_defs_to_verify = [output_graph_def]
+    if output_saved_model_dir is not None:
+      saved_model_graph_def = saved_model_utils.get_meta_graph_def(
+          output_saved_model_dir, tag_constants.SERVING).graph_def
+      self.assertTrue(isinstance(saved_model_graph_def, graph_pb2.GraphDef))
+      graph_defs_to_verify.append(saved_model_graph_def)
+
+    for graph_def in graph_defs_to_verify:
+      node_name_to_op = {node.name: node.op for node in graph_def.node}
+      self.assertEqual({
+          "input": "Placeholder",
+          "my_trt_op_0": "TRTEngineOp",
+          "output": "Identity"
+      }, node_name_to_op)
+
+  def testCreateInferenceGraph_BasicConversion(self):
+    """Test case for trt_convert.create_inference_graph()."""
+    if not trt_convert.is_tensorrt_enabled():
+      return
+
+    # Use GraphDef as input.
+    self._TestCreateInferenceGraph()
+
+    # Use SavedModel as input.
+    tmp_dir = self.get_temp_dir()
+    input_saved_model_dir = os.path.join(tmp_dir, "in_dir1")
+    output_saved_model_dir = os.path.join(tmp_dir, "out_dir1")
+    self._WriteInputSavedModel(input_saved_model_dir)
+    self._TestCreateInferenceGraph(input_saved_model_dir,
+                                   output_saved_model_dir)
+
+  def _TestRun(self, sess, batch_size, expect_engine_is_run):
+    trt_convert.clear_test_values("")
+    result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
+    self.assertAllEqual([[[4.0]]] * batch_size, result)
+    execute_engine_test_value = ("done" if expect_engine_is_run else "")
+    execute_native_segment_test_value = ("" if expect_engine_is_run else "done")
+    self.assertEqual(execute_engine_test_value,
+                     trt_convert.get_test_value("my_trt_op_0:ExecuteTrtEngine"))
+    self.assertEqual(
+        execute_native_segment_test_value,
+        trt_convert.get_test_value("my_trt_op_0:ExecuteNativeSegment"))
+
+  def testCreateInferenceGraph_MinimumSegmentSize(self):
+    if not trt_convert.is_tensorrt_enabled():
+      return
+    output_graph_def = trt_convert.create_inference_graph(
+        self._GetGraphDef(), ["output"],
+        minimum_segment_size=5,
+        is_dynamic_op=False)
+    node_name_to_op = {node.name: node.op for node in output_graph_def.node}
+    self.assertEqual({
+        "v1/read": "Const",
+        "input": "Placeholder",
+        "add": "Add",
+        "mul": "Mul",
+        "add_1": "Add",
+        "output": "Identity"
+    }, node_name_to_op)
+
+  def testCreateInferenceGraph_DynamicOp(self):
+    if not trt_convert.is_tensorrt_enabled():
+      return
+    trt_convert.enable_test_value()
+
+    tmp_dir = self.get_temp_dir()
+    input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
+    output_saved_model_dir = os.path.join(tmp_dir, "out_dir2")
+    self._WriteInputSavedModel(input_saved_model_dir)
+    output_graph_def = trt_convert.create_inference_graph(
+        None,
+        None,
+        is_dynamic_op=True,
+        maximum_cached_engines=2,
+        input_saved_model_dir=input_saved_model_dir,
+        output_saved_model_dir=output_saved_model_dir,
+        session_config=self._GetConfigProto())
+
+    # Test the output GraphDef.
+    with ops.Graph().as_default():
+      importer.import_graph_def(output_graph_def, name="")
+      with self.test_session(config=self._GetConfigProto()) as sess:
+        # Run with batch size 1, a new engine is created and cached.
+        self._TestRun(sess, 1, True)
+        # Run with batch size 2, a new engine is created and cached.
+        self._TestRun(sess, 2, True)
+        # Run with batch size 3, since the number of cached engines has reached
+        # the max, it should fall back to TF function.
+        self._TestRun(sess, 3, False)
+
+    # Test the output SavedModel
+    with ops.Graph().as_default():
+      with self.test_session(config=self._GetConfigProto()) as sess:
+        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
+        # Run with batch size 1, a new engine is created and cached.
+        self._TestRun(sess, 1, True)
+        # Run with batch size 2, a new engine is created and cached.
+        self._TestRun(sess, 2, True)
+        # Run with batch size 3, since the number of cached engines has reached
+        # the max, it should fall back to TF function.
+        self._TestRun(sess, 3, False)
+
+  def testCreateInferenceGraph_StaticOp(self):
+    if not trt_convert.is_tensorrt_enabled():
+      return
+    trt_convert.enable_test_value()
+
+    tmp_dir = self.get_temp_dir()
+    input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
+    output_saved_model_dir = os.path.join(tmp_dir, "out_dir3")
+    self._WriteInputSavedModel(input_saved_model_dir)
+    output_graph_def = trt_convert.create_inference_graph(
+        None,
+        None,
+        max_batch_size=1,
+        is_dynamic_op=False,
+        maximum_cached_engines=2,  # This is noop, added just for testing.
+        input_saved_model_dir=input_saved_model_dir,
+        output_saved_model_dir=output_saved_model_dir,
+        session_config=self._GetConfigProto())
+
+    # Test the output GraphDef.
+    with ops.Graph().as_default():
+      importer.import_graph_def(output_graph_def, name="")
+      with self.test_session(config=self._GetConfigProto()) as sess:
+        # Run with batch size 1, the default engine embedded in the graphdef
+        # will be used.
+        self._TestRun(sess, 1, True)
+        # Run with batch size 2, which exceed the max_batch_size, it should fall
+        # back to TF function.
+        self._TestRun(sess, 2, False)
+
+    # Test the output SavedModel
+    with ops.Graph().as_default():
+      with self.test_session(config=self._GetConfigProto()) as sess:
+        loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
+        # Run with batch size 1, the default engine embedded in the graphdef
+        # will be used.
+        self._TestRun(sess, 1, True)
+        # Run with batch size 2, which exceed the max_batch_size, it should fall
+        # back to TF function.
+        self._TestRun(sess, 2, False)
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
index 090aa8bdb0..d26f260086 100644
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -191,7 +191,7 @@ def user(multi_engine,
       minimum_segment_size=2,  # minimum number of nodes in an engine
       is_dynamic_op=False,
       maximum_cached_engines=1,
-      cached_engine_batches=[])
+      cached_engine_batch_sizes=[])
   o1 = run_graph(orig_graph, dummy_input)
   o2 = run_graph(trt_graph, dummy_input)
   o3 = run_graph(trt_graph, dummy_input)
@@ -206,7 +206,7 @@ def user(multi_engine,
       minimum_segment_size=2,  # minimum number of nodes in an engine
       is_dynamic_op=False,
       maximum_cached_engines=1,
-      cached_engine_batches=[])
+      cached_engine_batch_sizes=[])
   int8_calib_gdef = trt.create_inference_graph(
       input_graph_def=orig_graph,
       outputs=["output"],
@@ -216,7 +216,7 @@ def user(multi_engine,
       minimum_segment_size=2,  # minimum number of nodes in an engine
       is_dynamic_op=False,
       maximum_cached_engines=1,
-      cached_engine_batches=[])
+      cached_engine_batch_sizes=[])
   o4 = run_graph(fp16_graph, dummy_input)
   _ = run_calibration(int8_calib_gdef, dummy_input)
   int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index 65ca21cf37..fc647e4eb9 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -30,7 +30,6 @@ from tensorflow.contrib.tensorrt.python import trt_convert
 from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
 # pylint: enable=unused-import
 from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import graph_io
 from tensorflow.python.framework import importer
@@ -50,7 +49,7 @@ RunParams = namedtuple(
 ConversionParams = namedtuple("ConversionParams", [
     "max_batch_size", "max_workspace_size_bytes", "precision_mode",
     "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
-    "cached_engine_batches"
+    "cached_engine_batch_sizes"
 ])
 
 PRECISION_MODES = ["FP32", "FP16", "INT8"]
@@ -139,7 +138,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
         minimum_segment_size=2,
         is_dynamic_op=run_params.dynamic_engine,
         maximum_cached_engines=1,
-        cached_engine_batches=None)
+        cached_engine_batch_sizes=None)
 
   def ShouldRunTest(self, run_params):
     """Whether to run the test."""
@@ -201,23 +200,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
   def _GetConfigProto(self, run_params, graph_state):
     """Get config proto based on specific settings."""
     if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
-      rewriter_cfg = rewriter_config_pb2.RewriterConfig()
-      rewriter_cfg.optimizers.extend(["constfold", "layout"])
-      custom_op = rewriter_cfg.custom_optimizers.add()
-      custom_op.name = "TensorRTOptimizer"
       trt_params = self.GetConversionParams(run_params)
-      custom_op.parameter_map["max_batch_size"].i = trt_params.max_batch_size
-      custom_op.parameter_map["max_workspace_size_bytes"].i = (
-          trt_params.max_workspace_size_bytes)
-      custom_op.parameter_map["precision_mode"].s = trt_params.precision_mode
-      custom_op.parameter_map["minimum_segment_size"].i = (
-          trt_params.minimum_segment_size)
-      custom_op.parameter_map["is_dynamic_op"].b = trt_params.is_dynamic_op
-      custom_op.parameter_map["maximum_cached_engines"].i = (
-          trt_params.maximum_cached_engines)
-      if trt_params.cached_engine_batches:
-        custom_op.parameter_map["cached_engine_batches"].list.i.extend(
-            trt_params.cached_engine_batches)
+      rewriter_cfg = trt_convert.tensorrt_rewriter_config(
+          trt_params.max_batch_size, trt_params.max_workspace_size_bytes,
+          trt_params.precision_mode, trt_params.minimum_segment_size,
+          trt_params.is_dynamic_op, trt_params.maximum_cached_engines,
+          trt_params.cached_engine_batch_sizes)
 
       graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
     else:
@@ -308,7 +296,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
         minimum_segment_size=trt_params.minimum_segment_size,
         is_dynamic_op=trt_params.is_dynamic_op,
         maximum_cached_engines=trt_params.maximum_cached_engines,
-        cached_engine_batches=trt_params.cached_engine_batches)
+        cached_engine_batch_sizes=trt_params.cached_engine_batch_sizes)
 
   def _WriteGraph(self, run_params, gdef, graph_state):
     if graph_state == GraphState.ORIGINAL:
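
The diff also factors the RewriterConfig construction into the new trt_convert.tensorrt_rewriter_config() helper, which tf_trt_integration_test_base.py now calls instead of building the custom optimizer by hand. A rough sketch of wiring that helper into a session ConfigProto, following the pattern in the test base; the parameter values below are arbitrary examples, not defaults from this commit:

    # Sketch only: run the TensorRTOptimizer Grappler pass via session options.
    from tensorflow.contrib.tensorrt.python import trt_convert
    from tensorflow.core.protobuf import config_pb2

    rewriter_cfg = trt_convert.tensorrt_rewriter_config(
        max_batch_size=16,
        max_workspace_size_bytes=1 << 30,
        precision_mode="FP32",
        minimum_segment_size=3,
        is_dynamic_op=True,
        maximum_cached_engines=2,
        cached_engine_batch_sizes=[1, 16])
    graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
    session_config = config_pb2.ConfigProto(graph_options=graph_options)
    # Passing session_config to tf.Session(...) makes Grappler apply the
    # TensorRT conversion when the graph is optimized.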