diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py')
-rw-r--r-- | tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py | 28 |
1 files changed, 8 insertions, 20 deletions
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py index 65ca21cf37..fc647e4eb9 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -30,7 +30,6 @@ from tensorflow.contrib.tensorrt.python import trt_convert from tensorflow.contrib.tensorrt.python.ops import trt_engine_op # pylint: enable=unused-import from tensorflow.core.protobuf import config_pb2 -from tensorflow.core.protobuf import rewriter_config_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import graph_io from tensorflow.python.framework import importer @@ -50,7 +49,7 @@ RunParams = namedtuple( ConversionParams = namedtuple("ConversionParams", [ "max_batch_size", "max_workspace_size_bytes", "precision_mode", "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines", - "cached_engine_batches" + "cached_engine_batch_sizes" ]) PRECISION_MODES = ["FP32", "FP16", "INT8"] @@ -139,7 +138,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): minimum_segment_size=2, is_dynamic_op=run_params.dynamic_engine, maximum_cached_engines=1, - cached_engine_batches=None) + cached_engine_batch_sizes=None) def ShouldRunTest(self, run_params): """Whether to run the test.""" @@ -201,23 +200,12 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): def _GetConfigProto(self, run_params, graph_state): """Get config proto based on specific settings.""" if graph_state != GraphState.ORIGINAL and run_params.use_optimizer: - rewriter_cfg = rewriter_config_pb2.RewriterConfig() - rewriter_cfg.optimizers.extend(["constfold", "layout"]) - custom_op = rewriter_cfg.custom_optimizers.add() - custom_op.name = "TensorRTOptimizer" trt_params = self.GetConversionParams(run_params) - custom_op.parameter_map["max_batch_size"].i = trt_params.max_batch_size - custom_op.parameter_map["max_workspace_size_bytes"].i = ( - trt_params.max_workspace_size_bytes) - custom_op.parameter_map["precision_mode"].s = trt_params.precision_mode - custom_op.parameter_map["minimum_segment_size"].i = ( - trt_params.minimum_segment_size) - custom_op.parameter_map["is_dynamic_op"].b = trt_params.is_dynamic_op - custom_op.parameter_map["maximum_cached_engines"].i = ( - trt_params.maximum_cached_engines) - if trt_params.cached_engine_batches: - custom_op.parameter_map["cached_engine_batches"].list.i.extend( - trt_params.cached_engine_batches) + rewriter_cfg = trt_convert.tensorrt_rewriter_config( + trt_params.max_batch_size, trt_params.max_workspace_size_bytes, + trt_params.precision_mode, trt_params.minimum_segment_size, + trt_params.is_dynamic_op, trt_params.maximum_cached_engines, + trt_params.cached_engine_batch_sizes) graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg) else: @@ -308,7 +296,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): minimum_segment_size=trt_params.minimum_segment_size, is_dynamic_op=trt_params.is_dynamic_op, maximum_cached_engines=trt_params.maximum_cached_engines, - cached_engine_batches=trt_params.cached_engine_batches) + cached_engine_batch_sizes=trt_params.cached_engine_batch_sizes) def _WriteGraph(self, run_params, gdef, graph_state): if graph_state == GraphState.ORIGINAL: |