author    Guangda Lai <laigd@google.com>                    2018-09-12 19:05:02 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-09-12 19:08:27 -0700
commit    28ede9ed7caee0ce2731d95cc0eb9aff7f360105 (patch)
tree      b71838ce0eac5bd833bd4b2555cec9428ca3c87b /tensorflow/contrib/tensorrt
parent    30b711b07570b12c8880532aede428503c35e310 (diff)
Add SavedModel support to TensorRT's create_inference_graph() API.
PiperOrigin-RevId: 212743550
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--  tensorflow/contrib/tensorrt/BUILD                                  |  31
-rw-r--r--  tensorflow/contrib/tensorrt/python/trt_convert.py                  | 319
-rw-r--r--  tensorflow/contrib/tensorrt/python/trt_convert_test.py             | 293
-rw-r--r--  tensorflow/contrib/tensorrt/test/test_tftrt.py                     |   6
-rw-r--r--  tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py   |  28
5 files changed, 577 insertions, 100 deletions
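
For reference, below is a minimal sketch of how the new SavedModel conversion path added by this change can be invoked. The API names (create_inference_graph, TrtPrecisionMode, input_saved_model_dir, output_saved_model_dir) come from the diff below; the directory paths are hypothetical placeholders.

from tensorflow.contrib.tensorrt.python import trt_convert

# Convert the serving graph of a SavedModel; inputs/outputs are taken from its
# SignatureDefs, and a new SavedModel containing TRTEngineOps is written out.
# The directories below are placeholders.
trt_graph_def = trt_convert.create_inference_graph(
    input_graph_def=None,   # read the graph from the SavedModel instead
    outputs=None,           # not needed when converting a SavedModel
    max_batch_size=1,
    precision_mode=trt_convert.TrtPrecisionMode.FP32,
    input_saved_model_dir="/tmp/my_saved_model",
    output_saved_model_dir="/tmp/my_trt_saved_model")
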
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 122a67a407..9e8979bce4 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -19,6 +19,7 @@ load(
"tf_gen_op_libs",
"tf_gen_op_wrapper_py",
)
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
@@ -181,7 +182,12 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":wrap_conversion",
+ "//tensorflow/python:graph_util",
+ "//tensorflow/python:session",
"//tensorflow/python:tf_optimizer",
+ "//tensorflow/python/saved_model:builder",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:tag_constants",
],
)
@@ -410,6 +416,31 @@ py_library(
],
)
+cuda_py_test(
+ name = "trt_convert_test",
+ srcs = ["python/trt_convert_test.py"],
+ additional_deps = [
+ ":trt_convert_py",
+ ":trt_ops_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:graph_util",
+ "//tensorflow/python/saved_model:builder",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow/python/saved_model:signature_def_utils",
+ "//tensorflow/python/saved_model:tag_constants",
+ "//tensorflow/python/saved_model:utils",
+ "//tensorflow/python/tools:freeze_graph_lib",
+ "//tensorflow/python/tools:saved_model_utils",
+ ],
+ tags = [
+ "no_cuda_on_cpu_tap",
+ "no_windows",
+ "nomac",
+ ],
+)
+
cuda_py_tests(
name = "tf_trt_integration_test",
srcs = [
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 4116f2fe30..369e73b5a6 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-# pylint: disable=unused-import,line-too-long
import six as _six
+# pylint: disable=unused-import,line-too-long
from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values
@@ -28,55 +28,179 @@ from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_vers
from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value
from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled
+# pylint: enable=unused-import,line-too-long
from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
from tensorflow.python.framework import errors_impl as _impl
+from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.platform import tf_logging
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader_impl
+from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver
-# pylint: enable=unused-import,line-too-long
+
+if _six.PY2:
+ _to_bytes = lambda s: s
+ _to_string = lambda s: s
+else:
+ _to_bytes = lambda s: s.encode("utf-8", errors="surrogateescape")
+ _to_string = lambda s: s.decode("utf-8")
+
+
+class TrtPrecisionMode(object):
+ FP32 = "FP32"
+ FP16 = "FP16"
+ INT8 = "INT8"
+
+ @staticmethod
+ def supported_precision_modes():
+ return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8]
+
+
+def tensorrt_rewriter_config(max_batch_size=1,
+ max_workspace_size_bytes=2 << 20,
+ precision_mode=TrtPrecisionMode.FP32,
+ minimum_segment_size=3,
+ is_dynamic_op=False,
+ maximum_cached_engines=1,
+ cached_engine_batch_sizes=None):
+ """Returns a RewriterConfig proto for TRT transformation.
+
+ Args:
+ max_batch_size: max size for the input batch.
+ max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+ engine can use at execution time. This corresponds to the 'workspaceSize'
+ parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+ precision_mode: one of TrtPrecisionMode.supported_precision_modes().
+ minimum_segment_size: the minimum number of nodes required for a subgraph to
+ be replaced by TRTEngineOp.
+ is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
+ network and engine at run time.
+ maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
+ If the number of cached engines is already at max but none of them can
+ serve the input, the TRTEngineOp will fall back to run the TF function
+ based on which the TRTEngineOp is created.
+ cached_engine_batch_sizes: a list of batch sizes used to create cached
+ engines, only used when is_dynamic_op is True. The length of the list
+ should not exceed maximum_cached_engines, and the dynamic TRT op will
+ use this list to determine the batch sizes of the cached engines, instead
+ of making the decision on the fly. This is useful when we know the most
+ common batch size(s) the application is going to generate.
+
+ Returns:
+ A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
+
+ Raises:
+ ValueError: if precision_mode is invalid, or if cached_engine_batch_sizes
+ contains more than maximum_cached_engines items.
+ TypeError: if cached_engine_batch_sizes is not a list.
+ """
+ if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes():
+ raise ValueError(("precision mode '{}' is not supported."
+ "It should be one of {}").format(
+ precision_mode,
+ TrtPrecisionMode.supported_precision_modes()))
+
+ rewriter_cfg = rewriter_config_pb2.RewriterConfig()
+ rewriter_cfg.optimizers.extend(["constfold", "layout"])
+ optimizer = rewriter_cfg.custom_optimizers.add()
+ optimizer.name = "TensorRTOptimizer"
+ optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
+ optimizer.parameter_map["max_batch_size"].i = max_batch_size
+ optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
+ optimizer.parameter_map[
+ "max_workspace_size_bytes"].i = max_workspace_size_bytes
+ optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
+ optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
+ if cached_engine_batch_sizes:
+ if not isinstance(cached_engine_batch_sizes, list):
+ raise TypeError("cached_engine_batch_sizes should be a list.")
+ if len(cached_engine_batch_sizes) > maximum_cached_engines:
+ raise ValueError("cached_engine_batch_sizes should not contain more than "
+ "maximum_cached_engines items.")
+ optimizer.parameter_map["cached_engine_batches"].list.i.extend(
+ cached_engine_batch_sizes)
+ return rewriter_cfg
def create_inference_graph(input_graph_def,
outputs,
max_batch_size=1,
max_workspace_size_bytes=2 << 20,
- precision_mode="FP32",
+ precision_mode=TrtPrecisionMode.FP32,
minimum_segment_size=3,
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=None):
+ cached_engine_batch_sizes=None,
+ input_saved_model_dir=None,
+ input_saved_model_tags=None,
+ output_saved_model_dir=None,
+ session_config=None):
"""Python wrapper for the TRT transformation.
Args:
- input_graph_def: GraphDef object containing a model to be transformed.
- outputs: list of tensors or node names for the model outputs.
- max_batch_size: max size for the input batch
- max_workspace_size_bytes: parameter to control memory allocation (in Bytes)
- precision_mode: one of 'FP32', 'FP16' and 'INT8'
+ input_graph_def: a GraphDef object containing a model to be transformed. If
+ set to None, the graph will be read from the SavedModel loaded from
+ input_saved_model_dir.
+ outputs: list of tensors or node names for the model outputs. Only used when
+ input_graph_def is not None.
+ max_batch_size: max size for the input batch.
+ max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+ engine can use at execution time. This corresponds to the 'workspaceSize'
+ parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+ precision_mode: one of TrtPrecisionMode.supported_precision_modes().
minimum_segment_size: the minimum number of nodes required for a subgraph to
be replaced by TRTEngineOp.
is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
network and engine at run time.
maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
- cached_engine_batches: batch sizes used to pre-create cached engines.
+ If the number of cached engines is already at max but none of them can
+ serve the input, the TRTEngineOp will fall back to run the TF function
+ based on which the TRTEngineOp is created.
+ cached_engine_batch_sizes: a list of batch sizes used to create cached
+ engines, only used when is_dynamic_op is True. The length of the list
+ should not exceed maximum_cached_engines, and the dynamic TRT op will
+ use this list to determine the batch sizes of the cached engines, instead
+ of making the decision on the fly. This is useful when we know the most
+ common batch size(s) the application is going to generate.
+ input_saved_model_dir: the directory to load the SavedModel which contains
+ the input graph to transform. Used only when input_graph_def is None.
+ input_saved_model_tags: list of tags to load the SavedModel.
+ output_saved_model_dir: if not None, construct a SavedModel using the
+ returned GraphDef and save it to the specified directory. This option only
+ works when the input graph is loaded from a SavedModel, i.e. when
+ input_saved_model_dir is specified and input_graph_def is None.
+ session_config: the ConfigProto used to create a Session. If not specified,
+ a default ConfigProto will be used.
Returns:
- New GraphDef with TRTEngineOps placed in graph replacing subgraphs.
+ A GraphDef transformed from input_graph_def (or the SavedModel graph def
+ loaded from input_saved_model_dir, if input_graph_def is not present), where
+ all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
+ function is added for each of the subgraphs.
+
+ If is_dynamic_op is True, each TRTEngineOp will contain a serialized
+ subgraph GraphDef, which will be converted to a TRT engine at execution time
+ and the TRT engine will be cached for future usage. A new TRT engine will be
+ created whenever none of the cached engines match the input shapes. If
+ the op fails to execute the TRT engine or the number of cached engines reaches
+ maximum_cached_engines, the op will fall back to call the corresponding TF
+ function.
+
+ If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
+ engine created from the corresponding subgraph. No more engines will be
+ created on the fly, and the op will fall back to call the corresponding TF
+ function when it fails to execute the engine.
Raises:
- ValueError: if the provided precision mode is invalid.
- RuntimeError: if the returned status message is malformed.
+ ValueError: if the combination of the parameters is invalid.
+ RuntimeError: if the TensorRT library version is incompatible.
"""
- supported_precision_modes = {"FP32": 0, "FP16": 1, "INT8": 2}
- if precision_mode.upper() not in supported_precision_modes:
- raise ValueError(("precision mode '{}' is not supported."
- "It should be one of {}").format(
- precision_mode, "{'FP32', 'FP16', 'INT8'}"))
- mode = supported_precision_modes[precision_mode.upper()]
compiled_version = get_linked_tensorrt_version()
loaded_version = get_loaded_tensorrt_version()
version_mismatch = False
@@ -101,61 +225,111 @@ def create_inference_graph(input_graph_def,
tf_logging.info("Running against TensorRT version %s" % ".".join(
[str(x) for x in loaded_version]))
- def py2bytes(inp):
- return inp
+ if session_config is None:
+ session_config = config_pb2.ConfigProto()
+
+ if input_saved_model_tags is None:
+ input_saved_model_tags = [tag_constants.SERVING]
+ saved_model_loader = None
+ grappler_meta_graph_def = None
- def py3bytes(inp):
- return inp.encode("utf-8", errors="surrogateescape")
+ if input_graph_def is None:
+ # Read from SavedModel and freeze the graph if necessary.
+ if input_saved_model_dir is None:
+ raise ValueError("input_graph_def and input_saved_model_dir cannot be "
+ "both None")
+ with ops.Graph().as_default():
+ with session.Session(config=session_config) as sess:
+ saved_model_loader = loader_impl.SavedModelLoader(input_saved_model_dir)
+ input_meta_graph_def = saved_model_loader.load(sess,
+ input_saved_model_tags)
+ output_node_names = set()
- def py2string(inp):
- return inp
+ def _gather_names(tensor_info):
+ """Get the node names from a TensorInfo."""
+ return set(
+ [tensor_info[key].name.split(":")[0] for key in tensor_info])
- def py3string(inp):
- return inp.decode("utf-8")
+ # Get inputs and outputs from all SignatureDefs.
+ for key in input_meta_graph_def.signature_def:
+ signature_def = input_meta_graph_def.signature_def[key]
+ output_node_names.update(_gather_names(signature_def.inputs))
+ output_node_names.update(_gather_names(signature_def.outputs))
- if _six.PY2:
- to_bytes = py2bytes
- to_string = py2string
+ # Freeze the variables in the SavedModel graph and copy the frozen
+ # graph over.
+ frozen_graph_def = graph_util.convert_variables_to_constants(
+ sess, sess.graph.as_graph_def(add_shapes=True),
+ list(output_node_names))
+ grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
+
+ # Copy the collections that are not variables.
+ for key in input_meta_graph_def.collection_def:
+ # TODO(laigd): currently we use the collection key to filter out
+ # collections that depend on variable ops, but this may miss some
+ # other user-defined collections. A better way would be to use
+ # CollectionDef::NodeList for the filtering.
+ if key not in [
+ "variables", "local_variables", "model_variables",
+ "trainable_variables", "train_op", "table_initializer"
+ ]:
+ grappler_meta_graph_def.collection_def[key].CopyFrom(
+ input_meta_graph_def.collection_def[key])
+
+ # Copy other information.
+ grappler_meta_graph_def.meta_info_def.CopyFrom(
+ input_meta_graph_def.meta_info_def)
+ for key in input_meta_graph_def.signature_def:
+ grappler_meta_graph_def.signature_def[key].CopyFrom(
+ input_meta_graph_def.signature_def[key])
+ # TODO(laigd): maybe add back AssetFileDef.
else:
- to_bytes = py3bytes
- to_string = py3string
-
- # Create MetaGraphDef
- graph = ops.Graph()
- with graph.as_default():
- importer.import_graph_def(input_graph_def, name="")
- meta_graph = saver.export_meta_graph(
- graph_def=graph.as_graph_def(), graph=graph)
- if outputs:
- output_collection = meta_graph_pb2.CollectionDef()
- output_list = output_collection.node_list.value
- for i in outputs:
- if isinstance(i, ops.Tensor):
- output_list.append(to_bytes(i.name))
- else:
- output_list.append(to_bytes(i))
- meta_graph.collection_def["train_op"].CopyFrom(output_collection)
+ if output_saved_model_dir is not None:
+ raise ValueError("output_saved_model_dir cannot be set when "
+ "input_graph_def is set")
+ # Create MetaGraphDef from input graph.
+ graph = ops.Graph()
+ with graph.as_default():
+ importer.import_graph_def(input_graph_def, name="")
+ grappler_meta_graph_def = saver.export_meta_graph(
+ graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
+ if outputs:
+ output_collection = meta_graph_pb2.CollectionDef()
+ output_list = output_collection.node_list.value
+ for i in outputs:
+ if isinstance(i, ops.Tensor):
+ output_list.append(_to_bytes(i.name))
+ else:
+ output_list.append(_to_bytes(i))
+ # TODO(laigd): use another key as the outputs are really not train_op.
+ grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
+ output_collection)
# Create RewriterConfig.
- rewriter_cfg = rewriter_config_pb2.RewriterConfig()
- rewriter_cfg.optimizers.extend(["constfold", "layout"])
- optimizer = rewriter_cfg.custom_optimizers.add()
- optimizer.name = "TensorRTOptimizer"
- optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
- optimizer.parameter_map["max_batch_size"].i = max_batch_size
- optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
- optimizer.parameter_map[
- "max_workspace_size_bytes"].i = max_workspace_size_bytes
- optimizer.parameter_map["precision_mode"].s = to_bytes(precision_mode)
- optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
- if cached_engine_batches:
- if not isinstance(cached_engine_batches, list):
- raise TypeError("cached_engine_batches should be a list.")
- optimizer.parameter_map["cached_engine_batches"].list.i.extend(
- cached_engine_batches)
+ rewriter_cfg = tensorrt_rewriter_config(
+ max_batch_size, max_workspace_size_bytes, precision_mode,
+ minimum_segment_size, is_dynamic_op, maximum_cached_engines,
+ cached_engine_batch_sizes)
+
+ # Run Grappler.
+ transformed_graph_def = tf_optimizer.OptimizeGraph(
+ rewriter_cfg, grappler_meta_graph_def, graph_id=b"tf_graph")
- return tf_optimizer.OptimizeGraph(
- rewriter_cfg, meta_graph, graph_id=b"tf_graph")
+ # Optionally write the transformed graphdef as SavedModel.
+ if output_saved_model_dir is not None:
+ saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
+ with ops.Graph().as_default():
+ importer.import_graph_def(transformed_graph_def, name="")
+ with session.Session(config=session_config) as sess:
+ saved_model_builder.add_meta_graph_and_variables(
+ sess,
+ input_saved_model_tags,
+ signature_def_map=grappler_meta_graph_def.signature_def)
+ # Ignore other meta graphs from the input SavedModel.
+ saved_model_builder.save()
+
+ return transformed_graph_def
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
@@ -164,22 +338,13 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
Args:
calibration_graph_def: the calibration GraphDef object with calibration data
is_dynamic_op: whether to create dynamic (rather than static) engines from calibration
+
Returns:
New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
Raises:
RuntimeError: if the returned status message is malformed.
"""
- def py2string(inp):
- return inp
-
- def py3string(inp):
- return inp.decode("utf-8")
-
- if _six.PY2:
- to_string = py2string
- else:
- to_string = py3string
is_calib_graph = False
for n in calibration_graph_def.node:
if n.op == "TRTEngineOp":
@@ -190,7 +355,7 @@ def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
return None
graph_str = calibration_graph_def.SerializeToString()
out = calib_convert(graph_str, is_dynamic_op)
- status = to_string(out[0])
+ status = _to_string(out[0])
output_graph_def_string = out[1]
del graph_str # Save some memory
if len(status) < 2:
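
The new tensorrt_rewriter_config() helper above can also be embedded in a session ConfigProto so that Grappler applies the TensorRT pass when the session runs, which is how the integration test base later in this diff builds its config. A small sketch (parameter values are illustrative only):

from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.core.protobuf import config_pb2

# Build the RewriterConfig carrying the TensorRTOptimizer custom pass and plug
# it into GraphOptions/ConfigProto; the values here are illustrative.
rewriter_cfg = trt_convert.tensorrt_rewriter_config(
    max_batch_size=8,
    max_workspace_size_bytes=1 << 25,
    precision_mode=trt_convert.TrtPrecisionMode.FP16,
    minimum_segment_size=3,
    is_dynamic_op=True,
    maximum_cached_engines=2,
    cached_engine_batch_sizes=[1, 8])
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
session_config = config_pb2.ConfigProto(graph_options=graph_options)
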
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
new file mode 100644
index 0000000000..118a6680fd
--- /dev/null
+++ b/tensorflow/contrib/tensorrt/python/trt_convert_test.py
@@ -0,0 +1,293 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities to test TF-TensorRT integration."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.tensorrt.python import trt_convert
+# pylint: disable=unused-import
+from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
+# pylint: enable=unused-import
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import graph_util
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.saved_model import utils
+from tensorflow.python.tools import saved_model_utils
+
+
+class TrtConvertTest(test_util.TensorFlowTestCase):
+ """Class to test Tensorflow-TensorRT integration python API."""
+
+ def testTensorrtRewriterConfig(self):
+ """Test case for trt_convert.tensorrt_rewriter_config()."""
+ rewriter_cfg = trt_convert.tensorrt_rewriter_config(
+ max_batch_size=128,
+ max_workspace_size_bytes=1234,
+ precision_mode="INT8",
+ minimum_segment_size=10,
+ is_dynamic_op=True,
+ maximum_cached_engines=2,
+ cached_engine_batch_sizes=[1, 128])
+ trt_optimizer = None
+ for optimizer in rewriter_cfg.custom_optimizers:
+ if optimizer.name == "TensorRTOptimizer":
+ self.assertTrue(trt_optimizer is None)
+ trt_optimizer = optimizer
+ self.assertTrue(trt_optimizer is not None)
+ for key in [
+ "minimum_segment_size", "max_batch_size", "is_dynamic_op",
+ "max_workspace_size_bytes", "precision_mode", "maximum_cached_engines",
+ "cached_engine_batches"
+ ]:
+ self.assertTrue(key in trt_optimizer.parameter_map)
+ self.assertEqual(10, trt_optimizer.parameter_map["minimum_segment_size"].i)
+ self.assertEqual(128, trt_optimizer.parameter_map["max_batch_size"].i)
+ self.assertEqual(True, trt_optimizer.parameter_map["is_dynamic_op"].b)
+ self.assertEqual(1234,
+ trt_optimizer.parameter_map["max_workspace_size_bytes"].i)
+ self.assertEqual(
+ trt_convert._to_bytes("INT8"),
+ trt_optimizer.parameter_map["precision_mode"].s)
+ self.assertEqual(2, trt_optimizer.parameter_map["maximum_cached_engines"].i)
+ self.assertEqual(
+ [1, 128],
+ trt_optimizer.parameter_map["cached_engine_batches"].list.i)
+
+ def _GetConfigProto(self):
+ """Get ConfigProto for session creation."""
+ config = config_pb2.ConfigProto(
+ gpu_options=config_pb2.GPUOptions(allow_growth=True))
+ return config
+
+ def _GetGraph(self):
+ """Get the graph for testing."""
+ g = ops.Graph()
+ with g.as_default():
+ with g.device("/GPU:0"):
+ inp = array_ops.placeholder(
+ dtype=dtypes.float32, shape=[None, 1, 1], name="input")
+ var = variables.Variable([[[1.0]]], dtype=dtypes.float32, name="v1")
+ add = inp + var.value()
+ mul = inp * add
+ add = mul + add
+ out = array_ops.identity(add, name="output")
+ return g, var, inp, out
+
+ def _GetGraphDef(self):
+ """Get the graph def for testing."""
+ g, var, _, _ = self._GetGraph()
+ with self.test_session(graph=g, config=self._GetConfigProto()) as sess:
+ sess.run(var.initializer)
+ graph_def = graph_util.convert_variables_to_constants(
+ sess, g.as_graph_def(add_shapes=True), ["output"])
+ node_name_to_op = {node.name: node.op for node in graph_def.node}
+ self.assertEqual({
+ "v1": "Const",
+ "v1/read": "Identity",
+ "input": "Placeholder",
+ "add": "Add",
+ "mul": "Mul",
+ "add_1": "Add",
+ "output": "Identity"
+ }, node_name_to_op)
+ return graph_def
+
+ def _WriteInputSavedModel(self, input_saved_model_dir):
+ """Write the saved model as an input for testing."""
+ g, var, inp, out = self._GetGraph()
+ signature_def = signature_def_utils.build_signature_def(
+ inputs={"myinput": utils.build_tensor_info(inp)},
+ outputs={"myoutput": utils.build_tensor_info(out)},
+ method_name=signature_constants.PREDICT_METHOD_NAME)
+ saved_model_builder = builder.SavedModelBuilder(input_saved_model_dir)
+ with self.test_session(graph=g, config=self._GetConfigProto()) as sess:
+ sess.run(var.initializer)
+ saved_model_builder.add_meta_graph_and_variables(
+ sess, [tag_constants.SERVING],
+ signature_def_map={"mypredict": signature_def})
+ saved_model_builder.save()
+
+ def _TestCreateInferenceGraph(self,
+ input_saved_model_dir=None,
+ output_saved_model_dir=None):
+ """General method to test trt_convert.create_inference_graph()."""
+ input_graph_def = None if input_saved_model_dir else self._GetGraphDef()
+ output_graph_def = trt_convert.create_inference_graph(
+ input_graph_def, ["output"],
+ input_saved_model_dir=input_saved_model_dir,
+ output_saved_model_dir=output_saved_model_dir,
+ session_config=self._GetConfigProto())
+ graph_defs_to_verify = [output_graph_def]
+ if output_saved_model_dir is not None:
+ saved_model_graph_def = saved_model_utils.get_meta_graph_def(
+ output_saved_model_dir, tag_constants.SERVING).graph_def
+ self.assertTrue(isinstance(saved_model_graph_def, graph_pb2.GraphDef))
+ graph_defs_to_verify.append(saved_model_graph_def)
+
+ for graph_def in graph_defs_to_verify:
+ node_name_to_op = {node.name: node.op for node in graph_def.node}
+ self.assertEqual({
+ "input": "Placeholder",
+ "my_trt_op_0": "TRTEngineOp",
+ "output": "Identity"
+ }, node_name_to_op)
+
+ def testCreateInferenceGraph_BasicConversion(self):
+ """Test case for trt_convert.create_inference_graph()."""
+ if not trt_convert.is_tensorrt_enabled():
+ return
+
+ # Use GraphDef as input.
+ self._TestCreateInferenceGraph()
+
+ # Use SavedModel as input.
+ tmp_dir = self.get_temp_dir()
+ input_saved_model_dir = os.path.join(tmp_dir, "in_dir1")
+ output_saved_model_dir = os.path.join(tmp_dir, "out_dir1")
+ self._WriteInputSavedModel(input_saved_model_dir)
+ self._TestCreateInferenceGraph(input_saved_model_dir,
+ output_saved_model_dir)
+
+ def _TestRun(self, sess, batch_size, expect_engine_is_run):
+ trt_convert.clear_test_values("")
+ result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]] * batch_size})
+ self.assertAllEqual([[[4.0]]] * batch_size, result)
+ execute_engine_test_value = ("done" if expect_engine_is_run else "")
+ execute_native_segment_test_value = ("" if expect_engine_is_run else "done")
+ self.assertEqual(execute_engine_test_value,
+ trt_convert.get_test_value("my_trt_op_0:ExecuteTrtEngine"))
+ self.assertEqual(
+ execute_native_segment_test_value,
+ trt_convert.get_test_value("my_trt_op_0:ExecuteNativeSegment"))
+
+ def testCreateInferenceGraph_MinimumSegmentSize(self):
+ if not trt_convert.is_tensorrt_enabled():
+ return
+ output_graph_def = trt_convert.create_inference_graph(
+ self._GetGraphDef(), ["output"],
+ minimum_segment_size=5,
+ is_dynamic_op=False)
+ node_name_to_op = {node.name: node.op for node in output_graph_def.node}
+ self.assertEqual({
+ "v1/read": "Const",
+ "input": "Placeholder",
+ "add": "Add",
+ "mul": "Mul",
+ "add_1": "Add",
+ "output": "Identity"
+ }, node_name_to_op)
+
+ def testCreateInferenceGraph_DynamicOp(self):
+ if not trt_convert.is_tensorrt_enabled():
+ return
+ trt_convert.enable_test_value()
+
+ tmp_dir = self.get_temp_dir()
+ input_saved_model_dir = os.path.join(tmp_dir, "in_dir2")
+ output_saved_model_dir = os.path.join(tmp_dir, "out_dir2")
+ self._WriteInputSavedModel(input_saved_model_dir)
+ output_graph_def = trt_convert.create_inference_graph(
+ None,
+ None,
+ is_dynamic_op=True,
+ maximum_cached_engines=2,
+ input_saved_model_dir=input_saved_model_dir,
+ output_saved_model_dir=output_saved_model_dir,
+ session_config=self._GetConfigProto())
+
+ # Test the output GraphDef.
+ with ops.Graph().as_default():
+ importer.import_graph_def(output_graph_def, name="")
+ with self.test_session(config=self._GetConfigProto()) as sess:
+ # Run with batch size 1, a new engine is created and cached.
+ self._TestRun(sess, 1, True)
+ # Run with batch size 2, a new engine is created and cached.
+ self._TestRun(sess, 2, True)
+ # Run with batch size 3, since the number of cached engines has reached
+ # the max, it should fall back to TF function.
+ self._TestRun(sess, 3, False)
+
+ # Test the output SavedModel
+ with ops.Graph().as_default():
+ with self.test_session(config=self._GetConfigProto()) as sess:
+ loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
+ # Run with batch size 1, a new engine is created and cached.
+ self._TestRun(sess, 1, True)
+ # Run with batch size 2, a new engine is created and cached.
+ self._TestRun(sess, 2, True)
+ # Run with batch size 3, since the number of cached engines has reached
+ # the max, it should fall back to TF function.
+ self._TestRun(sess, 3, False)
+
+ def testCreateInferenceGraph_StaticOp(self):
+ if not trt_convert.is_tensorrt_enabled():
+ return
+ trt_convert.enable_test_value()
+
+ tmp_dir = self.get_temp_dir()
+ input_saved_model_dir = os.path.join(tmp_dir, "in_dir3")
+ output_saved_model_dir = os.path.join(tmp_dir, "out_dir3")
+ self._WriteInputSavedModel(input_saved_model_dir)
+ output_graph_def = trt_convert.create_inference_graph(
+ None,
+ None,
+ max_batch_size=1,
+ is_dynamic_op=False,
+ maximum_cached_engines=2, # This is a no-op, added just for testing.
+ input_saved_model_dir=input_saved_model_dir,
+ output_saved_model_dir=output_saved_model_dir,
+ session_config=self._GetConfigProto())
+
+ # Test the output GraphDef.
+ with ops.Graph().as_default():
+ importer.import_graph_def(output_graph_def, name="")
+ with self.test_session(config=self._GetConfigProto()) as sess:
+ # Run with batch size 1, the default engine embedded in the graphdef
+ # will be used.
+ self._TestRun(sess, 1, True)
+ # Run with batch size 2, which exceeds the max_batch_size; it should fall
+ # back to the TF function.
+ self._TestRun(sess, 2, False)
+
+ # Test the output SavedModel
+ with ops.Graph().as_default():
+ with self.test_session(config=self._GetConfigProto()) as sess:
+ loader.load(sess, [tag_constants.SERVING], output_saved_model_dir)
+ # Run with batch size 1, the default engine embedded in the graphdef
+ # will be used.
+ self._TestRun(sess, 1, True)
+ # Run with batch size 2, which exceeds the max_batch_size; it should fall
+ # back to the TF function.
+ self._TestRun(sess, 2, False)
+
+
+if __name__ == "__main__":
+ test.main()
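
The SavedModel written via output_saved_model_dir can then be served the same way the test above loads it: import the meta graph with loader.load() and run it through a Session. A sketch, reusing the toy graph's tensor names ("input:0"/"output:0") from the test and a hypothetical output directory:

from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants

# Load the converted SavedModel (path is a placeholder) and run one inference.
with ops.Graph().as_default():
  with session.Session() as sess:
    loader.load(sess, [tag_constants.SERVING], "/tmp/my_trt_saved_model")
    result = sess.run("output:0", feed_dict={"input:0": [[[1.0]]]})
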
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
index 090aa8bdb0..d26f260086 100644
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py
@@ -191,7 +191,7 @@ def user(multi_engine,
minimum_segment_size=2, # minimum number of nodes in an engine
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[])
+ cached_engine_batch_sizes=[])
o1 = run_graph(orig_graph, dummy_input)
o2 = run_graph(trt_graph, dummy_input)
o3 = run_graph(trt_graph, dummy_input)
@@ -206,7 +206,7 @@ def user(multi_engine,
minimum_segment_size=2, # minimum number of nodes in an engine
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[])
+ cached_engine_batch_sizes=[])
int8_calib_gdef = trt.create_inference_graph(
input_graph_def=orig_graph,
outputs=["output"],
@@ -216,7 +216,7 @@ def user(multi_engine,
minimum_segment_size=2, # minimum number of nodes in an engine
is_dynamic_op=False,
maximum_cached_engines=1,
- cached_engine_batches=[])
+ cached_engine_batch_sizes=[])
o4 = run_graph(fp16_graph, dummy_input)
_ = run_calibration(int8_calib_gdef, dummy_input)
int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index 65ca21cf37..fc647e4eb9 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -30,7 +30,6 @@ from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
# pylint: enable=unused-import
from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
@@ -50,7 +49,7 @@ RunParams = namedtuple(
ConversionParams = namedtuple("ConversionParams", [
"max_batch_size", "max_workspace_size_bytes", "precision_mode",
"minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
- "cached_engine_batches"
+ "cached_engine_batch_sizes"
])
PRECISION_MODES = ["FP32", "FP16", "INT8"]
@@ -139,7 +138,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
minimum_segment_size=2,
is_dynamic_op=run_params.dynamic_engine,
maximum_cached_engines=1,
- cached_engine_batches=None)
+ cached_engine_batch_sizes=None)
def ShouldRunTest(self, run_params):
"""Whether to run the test."""
@@ -201,23 +200,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def _GetConfigProto(self, run_params, graph_state):
"""Get config proto based on specific settings."""
if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
- rewriter_cfg = rewriter_config_pb2.RewriterConfig()
- rewriter_cfg.optimizers.extend(["constfold", "layout"])
- custom_op = rewriter_cfg.custom_optimizers.add()
- custom_op.name = "TensorRTOptimizer"
trt_params = self.GetConversionParams(run_params)
- custom_op.parameter_map["max_batch_size"].i = trt_params.max_batch_size
- custom_op.parameter_map["max_workspace_size_bytes"].i = (
- trt_params.max_workspace_size_bytes)
- custom_op.parameter_map["precision_mode"].s = trt_params.precision_mode
- custom_op.parameter_map["minimum_segment_size"].i = (
- trt_params.minimum_segment_size)
- custom_op.parameter_map["is_dynamic_op"].b = trt_params.is_dynamic_op
- custom_op.parameter_map["maximum_cached_engines"].i = (
- trt_params.maximum_cached_engines)
- if trt_params.cached_engine_batches:
- custom_op.parameter_map["cached_engine_batches"].list.i.extend(
- trt_params.cached_engine_batches)
+ rewriter_cfg = trt_convert.tensorrt_rewriter_config(
+ trt_params.max_batch_size, trt_params.max_workspace_size_bytes,
+ trt_params.precision_mode, trt_params.minimum_segment_size,
+ trt_params.is_dynamic_op, trt_params.maximum_cached_engines,
+ trt_params.cached_engine_batch_sizes)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
else:
@@ -308,7 +296,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
minimum_segment_size=trt_params.minimum_segment_size,
is_dynamic_op=trt_params.is_dynamic_op,
maximum_cached_engines=trt_params.maximum_cached_engines,
- cached_engine_batches=trt_params.cached_engine_batches)
+ cached_engine_batch_sizes=trt_params.cached_engine_batch_sizes)
def _WriteGraph(self, run_params, gdef, graph_state):
if graph_state == GraphState.ORIGINAL: