aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py')
-rw-r--r--tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py28
1 files changed, 8 insertions, 20 deletions
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index 65ca21cf37..fc647e4eb9 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -30,7 +30,6 @@ from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.contrib.tensorrt.python.ops import trt_engine_op
# pylint: enable=unused-import
from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
@@ -50,7 +49,7 @@ RunParams = namedtuple(
ConversionParams = namedtuple("ConversionParams", [
"max_batch_size", "max_workspace_size_bytes", "precision_mode",
"minimum_segment_size", "is_dynamic_op", "maximum_cached_engines",
- "cached_engine_batches"
+ "cached_engine_batch_sizes"
])
PRECISION_MODES = ["FP32", "FP16", "INT8"]
@@ -139,7 +138,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
minimum_segment_size=2,
is_dynamic_op=run_params.dynamic_engine,
maximum_cached_engines=1,
- cached_engine_batches=None)
+ cached_engine_batch_sizes=None)
def ShouldRunTest(self, run_params):
"""Whether to run the test."""
@@ -201,23 +200,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
def _GetConfigProto(self, run_params, graph_state):
"""Get config proto based on specific settings."""
if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
- rewriter_cfg = rewriter_config_pb2.RewriterConfig()
- rewriter_cfg.optimizers.extend(["constfold", "layout"])
- custom_op = rewriter_cfg.custom_optimizers.add()
- custom_op.name = "TensorRTOptimizer"
trt_params = self.GetConversionParams(run_params)
- custom_op.parameter_map["max_batch_size"].i = trt_params.max_batch_size
- custom_op.parameter_map["max_workspace_size_bytes"].i = (
- trt_params.max_workspace_size_bytes)
- custom_op.parameter_map["precision_mode"].s = trt_params.precision_mode
- custom_op.parameter_map["minimum_segment_size"].i = (
- trt_params.minimum_segment_size)
- custom_op.parameter_map["is_dynamic_op"].b = trt_params.is_dynamic_op
- custom_op.parameter_map["maximum_cached_engines"].i = (
- trt_params.maximum_cached_engines)
- if trt_params.cached_engine_batches:
- custom_op.parameter_map["cached_engine_batches"].list.i.extend(
- trt_params.cached_engine_batches)
+ rewriter_cfg = trt_convert.tensorrt_rewriter_config(
+ trt_params.max_batch_size, trt_params.max_workspace_size_bytes,
+ trt_params.precision_mode, trt_params.minimum_segment_size,
+ trt_params.is_dynamic_op, trt_params.maximum_cached_engines,
+ trt_params.cached_engine_batch_sizes)
graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
else:
@@ -308,7 +296,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
minimum_segment_size=trt_params.minimum_segment_size,
is_dynamic_op=trt_params.is_dynamic_op,
maximum_cached_engines=trt_params.maximum_cached_engines,
- cached_engine_batches=trt_params.cached_engine_batches)
+ cached_engine_batch_sizes=trt_params.cached_engine_batch_sizes)
def _WriteGraph(self, run_params, gdef, graph_state):
if graph_state == GraphState.ORIGINAL: