16 files changed, 548 insertions, 246 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 5b54cb76b4..26236a0435 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -387,18 +387,19 @@ cuda_py_tests( name = "tf_trt_integration_test", srcs = [ "test/base_test.py", - # "test/batch_matmul_test.py", - # "test/biasadd_matmul_test.py", - # "test/binary_tensor_weight_broadcast_test.py", # Blocked by trt4 installation - # "test/concatenation_test.py", # Blocked by trt4 installation + "test/batch_matmul_test.py", + "test/biasadd_matmul_test.py", + "test/binary_tensor_weight_broadcast_test.py", + "test/concatenation_test.py", "test/const_broadcast_test.py", + "test/manual_test.py", + "test/memory_alignment_test.py", "test/multi_connection_neighbor_engine_test.py", "test/neighboring_engine_test.py", "test/rank_two_test.py", - # "test/unary_test.py", # Blocked by trt4 installation - # "test/vgg_block_nchw_test.py", - # "test/vgg_block_test.py", - "test/memory_alignment_test.py", + "test/unary_test.py", + "test/vgg_block_nchw_test.py", + "test/vgg_block_test.py", ], additional_deps = [ ":tf_trt_integration_test_base", diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/contrib/tensorrt/test/base_test.py index 8ea5a63735..8453807a50 100644 --- a/tensorflow/contrib/tensorrt/test/base_test.py +++ b/tensorflow/contrib/tensorrt/test/base_test.py @@ -67,14 +67,15 @@ class SimpleSingleEngineTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which - # breaks the connection check, fix it. - # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add", - # "relu", "identity", "max_pool"] - expected_engines=["my_trt_op_0"], - expected_output_dims=(100, 6, 6, 6), - allclose_atol=1.e-03, - allclose_rtol=1.e-03) + expected_output_dims=(100, 6, 6, 6)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which + # breaks the connection check, fix it. + # - my_trt_op_0 should have ["weights", "conv", "bias", "bias_add", + # "relu", "identity", "max_pool"] + return ["my_trt_op_0"] class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): @@ -120,15 +121,16 @@ class SimpleMultiEnginesTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which - # breaks the connection check, fix it. - # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1", - # "add", "sub1"]; - # - my_trt_op_1 should have ["weights","conv", "div"] - expected_engines=["my_trt_op_0", "my_trt_op_1"], - expected_output_dims=(100, 12, 12, 6), - allclose_atol=1.e-03, - allclose_rtol=1.e-03) + expected_output_dims=(100, 12, 12, 6)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + # TODO(aaroey): LayoutOptimizer adds additional nodes to the graph which + # breaks the connection check, fix it. 
+ # - my_trt_op_0 should have ["mul", "sub", "div1", "mul1", "add1", + # "add", "sub1"]; + # - my_trt_op_1 should have ["weights","conv", "div"] + return ["my_trt_op_0", "my_trt_op_1"] class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): @@ -166,13 +168,14 @@ class PartiallyConvertedTestA(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines={ - # Only the first engine is built. - "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"] - }, - expected_output_dims=tuple(input_dims), - allclose_atol=1.e-06, - allclose_rtol=1.e-06) + expected_output_dims=tuple(input_dims)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return { + # Only the first engine is built. + "my_trt_op_0": ["c0", "c1", "add0", "add1", "mul0", "mul1"] + } class PartiallyConvertedTestB(PartiallyConvertedTestA): @@ -184,13 +187,12 @@ class PartiallyConvertedTestB(PartiallyConvertedTestA): trt_convert.clear_test_values("") trt_convert.add_test_value("my_trt_op_0:CreateTRTNode", "fail") - def GetParams(self): - """Create a graph containing two segment.""" - return super(PartiallyConvertedTestB, self).GetParams()._replace( - expected_engines={ - # Only the second engine is built. - "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"] - }) + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return { + # Only the second engine is built. + "my_trt_op_1": ["c2", "c3", "add2", "add3", "mul2", "mul3"] + } class ConstInputTest(trt_test.TfTrtIntegrationTestBase): @@ -226,13 +228,14 @@ class ConstInputTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines={ - "my_trt_op_0": ["add", "add1", "mul"], - "my_trt_op_1": ["add2", "add3", "mul1"] - }, - expected_output_dims=tuple(input_dims), - allclose_atol=1.e-06, - allclose_rtol=1.e-06) + expected_output_dims=tuple(input_dims)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return { + "my_trt_op_0": ["add", "add1", "mul"], + "my_trt_op_1": ["add2", "add3", "mul1"] + } class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase): @@ -256,10 +259,11 @@ class ConstDataInputSingleEngineTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines={"my_trt_op_0": ["c", "add", "add1", "mul"]}, - expected_output_dims=tuple(input_dims), - allclose_atol=1.e-06, - allclose_rtol=1.e-06) + expected_output_dims=tuple(input_dims)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return {"my_trt_op_0": ["c", "add", "add1", "mul"]} class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase): @@ -287,17 +291,18 @@ class ConstDataInputMultipleEnginesTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines={ - "my_trt_op_0": ["add2", "add3", "mul1"], - # Why segment ["add", "add1", "mul"] was assigned segment id 1 - # instead of 0: the parent node of this segment is actually const - # node 'c', but it's removed later since it's const output of the - # segment which is not allowed. 
- "my_trt_op_1": ["add", "add1", "mul"] - }, - expected_output_dims=tuple(input_dims), - allclose_atol=1.e-06, - allclose_rtol=1.e-06) + expected_output_dims=tuple(input_dims)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return { + "my_trt_op_0": ["add2", "add3", "mul1"], + # Why segment ["add", "add1", "mul"] was assigned segment id 1 + # instead of 0: the parent node of this segment is actually const + # node 'c', but it's removed later since it's const output of the + # segment which is not allowed. + "my_trt_op_1": ["add", "add1", "mul"] + } class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase): @@ -333,13 +338,14 @@ class ControlDependencyTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines={ - "my_trt_op_0": ["c1", "add", "add1", "mul"], - "my_trt_op_1": ["c2", "add2", "add3", "mul1"] - }, - expected_output_dims=tuple(input_dims), - allclose_atol=1.e-06, - allclose_rtol=1.e-06) + expected_output_dims=tuple(input_dims)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return { + "my_trt_op_0": ["c1", "add", "add1", "mul"], + "my_trt_op_1": ["c2", "add2", "add3", "mul1"] + } if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py index 2e1107e303..070a30557d 100644 --- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/batch_matmul_test.py @@ -66,10 +66,40 @@ class BatchMatMulTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name, w1_name, w2_name], input_dims=[input_dims, w1_dims, w2_dims], - expected_engines=["my_trt_op_0"], - expected_output_dims=(12, 5, 8, 7), - allclose_atol=1.e-03, - allclose_rtol=1.e-03) + expected_output_dims=(12, 5, 8, 7)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + if (run_params.dynamic_engine and + not trt_test.IsQuantizationMode(run_params.precision_mode)): + return ["my_trt_op_0", "my_trt_op_1"] + return ["my_trt_op_1"] + + def ExpectedEnginesToRun(self, run_params): + """Return the expected engines to run.""" + return ["my_trt_op_1"] + + def ShouldRunTest(self, run_params): + """Whether to run the test.""" + # TODO(aaroey): Trt library will fail like: + # + # ../builder/cudnnBuilder2.cpp:685: + # virtual std::vector<nvinfer1::query::Ports< + # nvinfer1::query::TensorRequirements>> + # nvinfer1::builder::Node::getSupportedFormats( + # const nvinfer1::query::Ports<nvinfer1::query::AbstractTensor>&, + # const nvinfer1::cudnn::HardwareContext&, + # nvinfer1::builder::Format::Type, + # const nvinfer1::builder::FormatTypeHack&) const: + # Assertion `sf' failed. + # + # To reproduce, run: + # bazel test -c opt --copt=-mavx \ + # --test_arg=BatchMatMulTest.testTfTrt_ToolConversion_INT8_DynamicEngine \ + # tensorflow/contrib/tensorrt:batch_matmul_test + # + # Investigate and fix it. 
+    return not trt_test.IsQuantizationMode(run_params.precision_mode)

 if __name__ == "__main__":

diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
index 8be32f59b4..3e30acc231 100644
--- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
@@ -102,13 +102,53 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
         gdef=g.as_graph_def(),
         input_names=[input_name],
         input_dims=[input_dims],
-        expected_engines=[
-            "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3",
-            "my_trt_op_4", "my_trt_op_5", "my_trt_op_6"
-        ],
-        expected_output_dims=(48, 89),
-        allclose_atol=1.e-03,
-        allclose_rtol=1.e-03)
+        expected_output_dims=(48, 89))
+
+  def GetConversionParams(self, run_params):
+    """Return a ConversionParams for test."""
+    return super(BiasaddMatMulTest,
+                 self).GetConversionParams(run_params)._replace(
+                     max_batch_size=48, maximum_cached_engines=2)
+
+  def _ValidEngines(self):
+    """Engines expected to build and run."""
+    return [
+        "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_6",
+        "my_trt_op_7", "my_trt_op_8", "my_trt_op_9"
+    ]
+
+  def _InvalidEngines(self):
+    """Engines that will cause a conversion error at build time."""
+    return ["my_trt_op_3", "my_trt_op_4", "my_trt_op_5"]
+
+  def ExpectedEnginesToBuild(self, run_params):
+    """Return the expected engines to build."""
+    # In dynamic engine mode the engines are built at execution time, not at
+    # conversion time, so build errors occur later. Here three of the engines
+    # will fail to build, but the corresponding engine ops are still created.
+    # TODO(aaroey, jjsjann123): fix this.
+    if (run_params.dynamic_engine and
+        not trt_test.IsQuantizationMode(run_params.precision_mode)):
+      return self._ValidEngines() + self._InvalidEngines()
+    return self._ValidEngines()
+
+  def ExpectedEnginesToRun(self, run_params):
+    """Return the expected engines to run."""
+    return self._ValidEngines()
+
+  def ShouldRunTest(self, run_params):
+    """Whether to run the test."""
+    # TODO(aaroey): Trt 4.0 forbids conversion for tensors with rank <3 in int8
+    # mode, which is a bug. Re-enable this when trt library is fixed.
+ return not trt_test.IsQuantizationMode(run_params.precision_mode) + + def ExpectedAbsoluteTolerance(self, run_params): + """The absolute tolerance to compare floating point results.""" + return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03 + + def ExpectedRelativeTolerance(self, run_params): + """The relative tolerance to compare floating point results.""" + return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03 if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py index 9316b14da0..89ef6a5baf 100644 --- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py @@ -109,27 +109,28 @@ class BinaryTensorWeightBroadcastTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines=[ - "my_trt_op_0", - "my_trt_op_1", - "my_trt_op_2", - "my_trt_op_3", - "my_trt_op_4", - "my_trt_op_5", - "my_trt_op_6", - "my_trt_op_7", - "my_trt_op_8", - "my_trt_op_9", - "my_trt_op_10", - "my_trt_op_11", - "my_trt_op_12", - "my_trt_op_13", - "my_trt_op_14", - "my_trt_op_15", - ], - expected_output_dims=(5, 23040), - allclose_atol=1.e-03, - allclose_rtol=1.e-03) + expected_output_dims=(5, 23040)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return [ + "my_trt_op_0", + "my_trt_op_1", + "my_trt_op_2", + "my_trt_op_3", + "my_trt_op_4", + "my_trt_op_5", + "my_trt_op_6", + "my_trt_op_7", + "my_trt_op_8", + "my_trt_op_9", + "my_trt_op_10", + "my_trt_op_11", + "my_trt_op_12", + "my_trt_op_13", + "my_trt_op_14", + "my_trt_op_15", + ] if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/contrib/tensorrt/test/concatenation_test.py index 1874b9dd45..c670b759dc 100644 --- a/tensorflow/contrib/tensorrt/test/concatenation_test.py +++ b/tensorflow/contrib/tensorrt/test/concatenation_test.py @@ -73,10 +73,11 @@ class ConcatenationTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines=["my_trt_op_0"], - expected_output_dims=(2, 126), - allclose_atol=1.e-03, - allclose_rtol=1.e-03) + expected_output_dims=(2, 126)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["my_trt_op_0"] if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py index 8c59000b70..d2d1d0e6dd 100644 --- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py +++ b/tensorflow/contrib/tensorrt/test/const_broadcast_test.py @@ -58,10 +58,19 @@ class ConstBroadcastTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines=['my_trt_op_0'], - expected_output_dims=(5, 12, 12, 1), - allclose_atol=1.e-02, - allclose_rtol=1.e-02) + expected_output_dims=(5, 12, 12, 1)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ['my_trt_op_0'] + + def ExpectedAbsoluteTolerance(self, run_params): + """The absolute tolerance to compare floating point results.""" + return 1.e-04 if run_params.precision_mode == 'FP32' else 1.e-02 + + def ExpectedRelativeTolerance(self, run_params): + """The relative 
tolerance to compare floating point results.""" + return 1.e-04 if run_params.precision_mode == 'FP32' else 1.e-02 if __name__ == '__main__': diff --git a/tensorflow/contrib/tensorrt/test/manual_test.py b/tensorflow/contrib/tensorrt/test/manual_test.py new file mode 100644 index 0000000000..60607681eb --- /dev/null +++ b/tensorflow/contrib/tensorrt/test/manual_test.py @@ -0,0 +1,125 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Basic tests for TF-TensorRT integration.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import ast +import numpy as np +import os + +from tensorflow.contrib.tensorrt.python import trt_convert +from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import gfile + + +class ManualTest(trt_test.TfTrtIntegrationTestBase): + + def __init__(self, methodName='runTest'): # pylint: disable=invalid-name + super(ManualTest, self).__init__(methodName) + self._params_map = None + + def _GetEnv(self): + """Get an environment variable specifying the manual test parameters. 
+ + The value of the environment variable is the string representation of a dict + which should contain the following keys: + - 'graph_path': the file path to the serialized frozen graphdef + - 'input_names': TfTrtIntegrationTestParams.input_names + - 'input_dims': TfTrtIntegrationTestParams.input_dims + - 'expected_output_dims': TfTrtIntegrationTestParams.expected_output_dims + - 'output_name': the name of op to fetch + - 'expected_engines_to_run': ExpectedEnginesToRun() will return this + - 'expected_engines_to_build': ExpectedEnginesToBuild() will return this + - 'max_batch_size': ConversionParams.max_batch_size + """ + return os.getenv('TRT_MANUAL_TEST_PARAMS', '') + + def _GetParamsMap(self): + """Parse the environment variable as a dict and return it.""" + if self._params_map is None: + self._params_map = ast.literal_eval(self._GetEnv()) + return self._params_map + + @property + def output_name(self): + return self._GetParamsMap()['output_name'] + + def GetParams(self): + """Testing conversion of manually provided frozen graph.""" + params_map = self._GetParamsMap() + gdef = graph_pb2.GraphDef() + with gfile.Open(params_map['graph_path'], 'rb') as f: + gdef.ParseFromString(f.read()) + return trt_test.TfTrtIntegrationTestParams( + gdef=gdef, + input_names=params_map['input_names'], + input_dims=params_map['input_dims'], + expected_output_dims=params_map['expected_output_dims']) + + def GetConversionParams(self, run_params): + """Return a ConversionParams for test.""" + conversion_params = super(ManualTest, self).GetConversionParams(run_params) + params_map = self._GetParamsMap() + if 'max_batch_size' in params_map: + conversion_params = conversion_params._replace( + max_batch_size=params_map['max_batch_size']) + return conversion_params + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return self._GetParamsMap()['expected_engines_to_build'] + + def ExpectedEnginesToRun(self, run_params): + """Return the expected engines to run.""" + params_map = self._GetParamsMap() + if 'expected_engines_to_run' in params_map: + return params_map['expected_engines_to_run'] + return self.ExpectedEnginesToBuild(run_params) + + def ExpectedAbsoluteTolerance(self, run_params): + """The absolute tolerance to compare floating point results.""" + params_map = self._GetParamsMap() + if 'atol' in params_map: + return params_map['atol'] + return 1.e-3 + + def ExpectedRelativeTolerance(self, run_params): + """The relative tolerance to compare floating point results.""" + params_map = self._GetParamsMap() + if 'rtol' in params_map: + return params_map['rtol'] + return 1.e-3 + + def ShouldRunTest(self, run_params): + """Whether to run the test.""" + return len(self._GetEnv()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py index 66eb6be757..fd2c165f35 100644 --- a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py +++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py @@ -62,10 +62,19 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines=["my_trt_op_0"], - expected_output_dims=(2, 15, 15, 10), - allclose_atol=1.e-02, - allclose_rtol=1.e-02) + expected_output_dims=(2, 15, 15, 10)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["my_trt_op_0"] + + def 
ExpectedAbsoluteTolerance(self, run_params): + """The absolute tolerance to compare floating point results.""" + return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-02 + + def ExpectedRelativeTolerance(self, run_params): + """The relative tolerance to compare floating point results.""" + return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-02 if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py index fd55b8cd99..13fdbcc5ad 100644 --- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py @@ -77,10 +77,11 @@ class MultiConnectionNeighborEngineTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines=["my_trt_op_0", "my_trt_op_1"], - expected_output_dims=(2, 4, 5, 4), - allclose_atol=1.e-03, - allclose_rtol=1.e-03) + expected_output_dims=(2, 4, 5, 4)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["my_trt_op_0", "my_trt_op_1"] if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py index 51c905a50b..d83f7278fc 100644 --- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py +++ b/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py @@ -59,13 +59,14 @@ class NeighboringEngineTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines={ - "my_trt_op_0": ["bias", "mul", "sub"], - "my_trt_op_1": ["weights", "conv"] - }, - expected_output_dims=(2, 4, 5, 4), - allclose_atol=1.e-03, - allclose_rtol=1.e-03) + expected_output_dims=(2, 4, 5, 4)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return { + "my_trt_op_0": ["bias", "mul", "sub"], + "my_trt_op_1": ["weights", "conv"] + } if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/rank_two_test.py b/tensorflow/contrib/tensorrt/test/rank_two_test.py index fbed1ac4e8..9a9c919fca 100644 --- a/tensorflow/contrib/tensorrt/test/rank_two_test.py +++ b/tensorflow/contrib/tensorrt/test/rank_two_test.py @@ -35,10 +35,10 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): def GetParams(self): """Test for rank 2 input in TF-TRT.""" input_names = ["input", "input2"] + # Two paths: first with rank 2 input, second with rank 4 input. 
input_dims = [[12, 5], [12, 5, 2, 2]] g = ops.Graph() with g.as_default(): - # Path 1 with rank 2 input outputs = [] for i in range(2): x = array_ops.placeholder( @@ -56,26 +56,33 @@ class RankTwoTest(trt_test.TfTrtIntegrationTestBase): q = array_ops.expand_dims(q, -1, name="expand%d_%d" % (i, j)) q = gen_math_ops.reciprocal(q, name="reciprocal%d" % i) outputs.append(q) - # Combine path 1 & 2 + # Combine both paths q = math_ops.add(outputs[0], outputs[1], name="add") array_ops.squeeze(q, name=self.output_name) return trt_test.TfTrtIntegrationTestParams( gdef=g.as_graph_def(), input_names=input_names, input_dims=input_dims, - expected_engines={ - "my_trt_op_0": [ - "add0_1", "add0_2", "add0_3", "c0_1", "c0_2", "c0_3", "abs0_1", - "abs0_2" - ], - "my_trt_op_1": [ - "add", "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3", - "abs1_1", "abs1_2", "reciprocal0", "reciprocal1" - ], - }, - expected_output_dims=tuple(input_dims[1]), - allclose_atol=1.e-03, - allclose_rtol=1.e-03) + expected_output_dims=tuple(input_dims[1])) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return { + "my_trt_op_0": [ + "add0_1", "add0_2", "add0_3", "c0_1", "c0_2", "c0_3", "abs0_1", + "abs0_2" + ], + "my_trt_op_1": [ + "add", "add1_1", "add1_2", "add1_3", "c1_1", "c1_2", "c1_3", + "abs1_1", "abs1_2", "reciprocal0", "reciprocal1" + ], + } + + def ShouldRunTest(self, run_params): + """Whether to run the test.""" + # TODO(aaroey): Trt 4.0 forbids conversion for tensors with rank <3 in int8 + # mode, which is a bug. Re-enable this when trt library is fixed. + return not trt_test.IsQuantizationMode(run_params.precision_mode) if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py index 6f85ada464..fc20950e45 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -38,19 +38,24 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops from tensorflow.python.platform import tf_logging as logging -TfTrtIntegrationTestParams = namedtuple("TfTrtIntegrationTestParams", [ - "gdef", "input_names", "input_dims", "expected_engines", - "expected_output_dims", "allclose_atol", "allclose_rtol" -]) +TfTrtIntegrationTestParams = namedtuple( + "TfTrtIntegrationTestParams", + ["gdef", "input_names", "input_dims", "expected_output_dims"]) RunParams = namedtuple( "RunParams", ["use_optimizer", "precision_mode", "dynamic_engine", "test_name"]) +ConversionParams = namedtuple("ConversionParams", [ + "max_batch_size", "max_workspace_size_bytes", "precision_mode", + "minimum_segment_size", "is_dynamic_op", "maximum_cached_engines", + "cached_engine_batches" +]) + PRECISION_MODES = ["FP32", "FP16", "INT8"] -def _IsQuantizationMode(mode): +def IsQuantizationMode(mode): return mode == "INT8" @@ -112,6 +117,10 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): super(TfTrtIntegrationTestBase, cls).setUpClass() trt_convert.enable_test_value() + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + super(TfTrtIntegrationTestBase, self).__init__(methodName) + self._trt_test_params = None + def setUp(self): """Setup method.""" super(TfTrtIntegrationTestBase, self).setUp() @@ -122,43 +131,96 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): """Return a TfTrtIntegrationTestParams for test, implemented by 
subclass.""" raise NotImplementedError() - def _PrepareRun(self, params, graph_state): + def GetConversionParams(self, run_params): + """Return a ConversionParams for test.""" + return ConversionParams( + max_batch_size=max( + [dims[0] for dims in self._GetParamsCached().input_dims]), + max_workspace_size_bytes=1 << 25, + precision_mode=self._ToBytes(run_params.precision_mode), + minimum_segment_size=2, + is_dynamic_op=run_params.dynamic_engine, + maximum_cached_engines=1, + cached_engine_batches=None) + + def ShouldRunTest(self, run_params): + """Whether to run the test.""" + return True + + def VerifyRunForEngine(self, engine_name, graph_state, expect_run=True): + """Verify the state of a particular engine after sess.run().""" + if graph_state == GraphState.ORIGINAL: + self._ExpectCalibration(engine_name, "") + self._ExpectNativeSegment(engine_name, "") + self._ExpectTrtEngine(engine_name, "") + elif graph_state == GraphState.CALIBRATE: + self._ExpectCalibration(engine_name, "done") + self._ExpectNativeSegment(engine_name, "done") + self._ExpectTrtEngine(engine_name, "") + elif graph_state == GraphState.INFERENCE: + self._ExpectCalibration(engine_name, "") + if expect_run: + self._ExpectNativeSegment(engine_name, "") + self._ExpectTrtEngine(engine_name, "done") + else: + self._ExpectNativeSegment(engine_name, "done") + self._ExpectTrtEngine(engine_name, "") + + def VerifyRun(self, run_params, graph_state): + """Verify the state of all engines after sess.run().""" + for engine_name in self.ExpectedEnginesToBuild(run_params): + expect_run = (engine_name in self.ExpectedEnginesToRun(run_params)) + self.VerifyRunForEngine(engine_name, graph_state, expect_run) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build, implemented by subclass.""" + raise NotImplementedError() + + def ExpectedEnginesToRun(self, run_params): + """Return the expected engines to run.""" + return self.ExpectedEnginesToBuild(run_params) + + def ExpectedAbsoluteTolerance(self, run_params): + """The absolute tolerance to compare floating point results.""" + return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03 + + def ExpectedRelativeTolerance(self, run_params): + """The relative tolerance to compare floating point results.""" + return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-03 + + def _GetParamsCached(self): + if self._trt_test_params is None: + self._trt_test_params = self.GetParams() + return self._trt_test_params + + def _PrepareRun(self, graph_state): """Set up necessary testing environment before calling sess.run().""" # Clear test values added by TRTEngineOp. 
    trt_convert.clear_test_values("my_trt_op_.*:ExecuteTrtEngine")
    trt_convert.clear_test_values("my_trt_op_.*:ExecuteCalibration")
    trt_convert.clear_test_values("my_trt_op_.*:ExecuteNativeSegment")

-  def _VerifyRun(self, params, graph_state):
-    """Verify the state after sess.run()."""
-    for engine_name in params.expected_engines:
-      if graph_state == GraphState.ORIGINAL:
-        self._ExpectCalibration(engine_name, "")
-        self._ExpectNativeSegment(engine_name, "")
-        self._ExpectTrtEngine(engine_name, "")
-      elif graph_state == GraphState.CALIBRATE:
-        self._ExpectCalibration(engine_name, "done")
-        self._ExpectNativeSegment(engine_name, "done")
-        self._ExpectTrtEngine(engine_name, "")
-      elif graph_state == GraphState.INFERENCE:
-        self._ExpectCalibration(engine_name, "")
-        self._ExpectNativeSegment(engine_name, "")
-        self._ExpectTrtEngine(engine_name, "done")
-
-  def _GetConfigProto(self, params, run_params, graph_state):
+  def _GetConfigProto(self, run_params, graph_state):
     """Get config proto based on specific settings."""
     if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
       rewriter_cfg = rewriter_config_pb2.RewriterConfig()
       rewriter_cfg.optimizers.extend(["constfold", "layout"])
       custom_op = rewriter_cfg.custom_optimizers.add()
       custom_op.name = "TensorRTOptimizer"
-      custom_op.parameter_map["minimum_segment_size"].i = 2
-      custom_op.parameter_map["max_batch_size"].i = max(
-          [dims[0] for dims in params.input_dims])
-      custom_op.parameter_map["is_dynamic_op"].b = run_params.dynamic_engine
-      custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
-      custom_op.parameter_map["precision_mode"].s = self._ToBytes(
-          run_params.precision_mode)
+      trt_params = self.GetConversionParams(run_params)
+      custom_op.parameter_map["max_batch_size"].i = trt_params.max_batch_size
+      custom_op.parameter_map["max_workspace_size_bytes"].i = (
+          trt_params.max_workspace_size_bytes)
+      custom_op.parameter_map["precision_mode"].s = trt_params.precision_mode
+      custom_op.parameter_map["minimum_segment_size"].i = (
+          trt_params.minimum_segment_size)
+      custom_op.parameter_map["is_dynamic_op"].b = trt_params.is_dynamic_op
+      custom_op.parameter_map["maximum_cached_engines"].i = (
+          trt_params.maximum_cached_engines)
+      if trt_params.cached_engine_batches:
+        custom_op.parameter_map["cached_engine_batches"].list.i.extend(
+            trt_params.cached_engine_batches)
+
       graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_cfg)
     else:
       graph_options = config_pb2.GraphOptions()
@@ -190,9 +252,15 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
   def _ExpectNativeSegment(self, engine_name, value):
     self._ExpectTestValue(engine_name, "ExecuteNativeSegment", value)

-  def _RunGraph(self, params, gdef, input_data, config, graph_state,
+  def _RunGraph(self,
+                run_params,
+                gdef,
+                input_data,
+                config,
+                graph_state,
                 num_runs=2):
     """Run given graphdef multiple times."""
+    params = self._GetParamsCached()
     assert len(params.input_names) == len(input_data)
     g = ops.Graph()
     with g.as_default():
@@ -208,35 +276,38 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
     val = None
     # Defaults to 2 runs to verify result across multiple runs is same.
for _ in range(num_runs): - self._PrepareRun(params, graph_state) + self._PrepareRun(graph_state) new_val = sess.run(out, {inp[i]: input_data[i] for i in range(len(inp))}) self.assertEqual(params.expected_output_dims, new_val.shape) if val is not None: self.assertAllEqual(val, new_val) val = new_val - self._VerifyRun(params, graph_state) + self.VerifyRun(run_params, graph_state) return val # Use real data that is representative of the inference dataset # for calibration. For this test script it is random data. - def _RunCalibration(self, params, gdef, input_data, config): + def _RunCalibration(self, run_params, gdef, input_data, config): """Run calibration on given graph.""" return self._RunGraph( - params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5) + run_params, gdef, input_data, config, GraphState.CALIBRATE, num_runs=5) - def _GetTrtGraphDef(self, params, run_params, gdef): + def _GetTrtGraphDef(self, run_params, gdef): """Return trt converted graphdef.""" + trt_params = self.GetConversionParams(run_params) return trt_convert.create_inference_graph( input_graph_def=gdef, outputs=[self.output_name], - max_batch_size=max([dims[0] for dims in params.input_dims]), - max_workspace_size_bytes=1 << 25, - precision_mode=run_params.precision_mode, - minimum_segment_size=2, - is_dynamic_op=run_params.dynamic_engine) - - def _WriteGraph(self, params, run_params, gdef, graph_state): + max_batch_size=trt_params.max_batch_size, + max_workspace_size_bytes=trt_params.max_workspace_size_bytes, + precision_mode=trt_params.precision_mode, + minimum_segment_size=trt_params.minimum_segment_size, + is_dynamic_op=trt_params.is_dynamic_op, + maximum_cached_engines=trt_params.maximum_cached_engines, + cached_engine_batches=trt_params.cached_engine_batches) + + def _WriteGraph(self, run_params, gdef, graph_state): if graph_state == GraphState.ORIGINAL: label = "Original" elif graph_state == GraphState.CALIBRATE: @@ -250,12 +321,13 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): logging.info("Writing graph to %s/%s", temp_dir, graph_name) graph_io.write_graph(gdef, temp_dir, graph_name) - def _VerifyConnections(self, params, converted_gdef): + def _VerifyConnections(self, expected_engines, converted_gdef): + params = self._GetParamsCached() old_to_new_node_map = { self._ToString(node.name): self._ToString(node.name) for node in params.gdef.node } - for engine_name, node_names in params.expected_engines.items(): + for engine_name, node_names in expected_engines.items(): for node_name in node_names: old_to_new_node_map[node_name] = engine_name name_to_node_map = { @@ -310,14 +382,16 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): msg="expected:\n%s\nvs actual:\n%s" % (sorted( expected_input_map.items()), sorted(actual_input_map.items()))) - def _VerifyGraphDef(self, params, run_params, gdef, graph_state): - self._WriteGraph(params, run_params, gdef, graph_state) + def _VerifyGraphDef(self, run_params, gdef, graph_state): + self._WriteGraph(run_params, gdef, graph_state) + params = self._GetParamsCached() + expected_engines = self.ExpectedEnginesToBuild(run_params) num_engines = 0 for node in gdef.node: if node.op == "TRTEngineOp": num_engines += 1 - self.assertTrue(node.name in params.expected_engines) + self.assertTrue(node.name in expected_engines) self.assertTrue(len(node.attr["serialized_segment"].s)) self.assertTrue(len(node.attr["segment_funcdef_name"].s)) self.assertEqual( @@ -328,7 +402,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): 
self.assertEqual(run_params.dynamic_engine, is_dynamic_engine) has_calibration_data = len(node.attr["calibration_data"].s) - if (_IsQuantizationMode(run_params.precision_mode) and + if (IsQuantizationMode(run_params.precision_mode) and graph_state == GraphState.INFERENCE): self.assertTrue(has_calibration_data) else: @@ -336,71 +410,70 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): if graph_state == GraphState.ORIGINAL: self.assertEqual(0, num_engines) else: - self.assertEqual(num_engines, len(params.expected_engines)) - if isinstance(params.expected_engines, dict): - self._VerifyConnections(params, gdef) + self.assertEqual(num_engines, len(expected_engines)) + if isinstance(expected_engines, dict): + self._VerifyConnections(expected_engines, gdef) # TODO(aaroey): consider verifying the corresponding TF function. - def RunTest(self, params, run_params): + def RunTest(self, run_params): + if not self.ShouldRunTest(run_params): + return assert run_params.precision_mode in PRECISION_MODES + params = self._GetParamsCached() input_data = [np.random.random_sample(dims) for dims in params.input_dims] input_gdef = params.gdef - self._VerifyGraphDef(params, run_params, input_gdef, GraphState.ORIGINAL) + self._VerifyGraphDef(run_params, input_gdef, GraphState.ORIGINAL) # Get reference result without running trt. - config_no_trt = self._GetConfigProto(params, run_params, - GraphState.ORIGINAL) + config_no_trt = self._GetConfigProto(run_params, GraphState.ORIGINAL) logging.info("Running original graph w/o trt, config:\n%s", str(config_no_trt)) - ref_result = self._RunGraph(params, input_gdef, input_data, config_no_trt, - GraphState.ORIGINAL) + ref_result = self._RunGraph(run_params, input_gdef, input_data, + config_no_trt, GraphState.ORIGINAL) # Run calibration if necessary. - if _IsQuantizationMode(run_params.precision_mode): + if IsQuantizationMode(run_params.precision_mode): - calib_config = self._GetConfigProto(params, run_params, - GraphState.CALIBRATE) + calib_config = self._GetConfigProto(run_params, GraphState.CALIBRATE) logging.info("Running calibration graph, config:\n%s", str(calib_config)) if run_params.use_optimizer: - result = self._RunCalibration(params, input_gdef, input_data, + result = self._RunCalibration(run_params, input_gdef, input_data, calib_config) else: - calib_gdef = self._GetTrtGraphDef(params, run_params, input_gdef) - self._VerifyGraphDef(params, run_params, calib_gdef, - GraphState.CALIBRATE) - result = self._RunCalibration(params, calib_gdef, input_data, + calib_gdef = self._GetTrtGraphDef(run_params, input_gdef) + self._VerifyGraphDef(run_params, calib_gdef, GraphState.CALIBRATE) + result = self._RunCalibration(run_params, calib_gdef, input_data, calib_config) - infer_gdef = trt_convert.calib_graph_to_infer_graph(calib_gdef) - self._VerifyGraphDef(params, run_params, infer_gdef, GraphState.INFERENCE) + infer_gdef = trt_convert.calib_graph_to_infer_graph( + calib_gdef, run_params.dynamic_engine) + self._VerifyGraphDef(run_params, infer_gdef, GraphState.INFERENCE) self.assertAllClose( ref_result, result, - atol=params.allclose_atol, - rtol=params.allclose_rtol) + atol=self.ExpectedAbsoluteTolerance(run_params), + rtol=self.ExpectedRelativeTolerance(run_params)) else: infer_gdef = input_gdef # Run inference. 
- infer_config = self._GetConfigProto(params, run_params, - GraphState.INFERENCE) + infer_config = self._GetConfigProto(run_params, GraphState.INFERENCE) logging.info("Running final inference graph, config:\n%s", str(infer_config)) if run_params.use_optimizer: - result = self._RunGraph(params, infer_gdef, input_data, infer_config, + result = self._RunGraph(run_params, infer_gdef, input_data, infer_config, GraphState.INFERENCE) else: - trt_infer_gdef = self._GetTrtGraphDef(params, run_params, infer_gdef) - self._VerifyGraphDef(params, run_params, trt_infer_gdef, - GraphState.INFERENCE) - result = self._RunGraph(params, trt_infer_gdef, input_data, infer_config, - GraphState.INFERENCE) + trt_infer_gdef = self._GetTrtGraphDef(run_params, infer_gdef) + self._VerifyGraphDef(run_params, trt_infer_gdef, GraphState.INFERENCE) + result = self._RunGraph(run_params, trt_infer_gdef, input_data, + infer_config, GraphState.INFERENCE) self.assertAllClose( ref_result, result, - atol=params.allclose_atol, - rtol=params.allclose_rtol) + atol=self.ExpectedAbsoluteTolerance(run_params), + rtol=self.ExpectedRelativeTolerance(run_params)) def testIdempotence(self): # Test that applying tensorrt optimizer or offline conversion tools multiple @@ -421,13 +494,12 @@ def _AddTests(test_class): """Gets a single test method based on the parameters.""" def _Test(self): - params = self.GetParams() logging.info( "Running test %s with parameters: use_optimizer=%s, " "precision_mode=%s, dynamic_engine=%s", "testTfTrt_" + run_params.test_name, run_params.use_optimizer, run_params.precision_mode, run_params.dynamic_engine) - self.RunTest(params, run_params) + self.RunTest(run_params) return _Test @@ -435,7 +507,7 @@ def _AddTests(test_class): dynamic_engine_options = [False, True] for (use_optimizer, precision_mode, dynamic_engine) in itertools.product( use_optimizer_options, PRECISION_MODES, dynamic_engine_options): - if _IsQuantizationMode(precision_mode): + if IsQuantizationMode(precision_mode): if use_optimizer: # TODO(aaroey): if use_optimizer is True we need to get the inference # graphdef using custom python wrapper class, which is not currently diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/contrib/tensorrt/test/unary_test.py index 500057a36d..5036bd7aaa 100644 --- a/tensorflow/contrib/tensorrt/test/unary_test.py +++ b/tensorflow/contrib/tensorrt/test/unary_test.py @@ -100,13 +100,14 @@ class UnaryTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name, input2_name], input_dims=[input_dims, input2_dims], - expected_engines=[ - "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", - "my_trt_op_4" - ], - expected_output_dims=(12, 5, 8, 12), - allclose_atol=1.e-03, - allclose_rtol=1.e-03) + expected_output_dims=(12, 5, 8, 12)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return [ + "my_trt_op_0", "my_trt_op_1", "my_trt_op_2", "my_trt_op_3", + "my_trt_op_4" + ] if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py index ab4d224db4..12f29ceebf 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py @@ -42,11 +42,9 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase): with g.as_default(): x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) x, _, _ = nn_impl.fused_batch_norm( - x, - 
np.random.randn(2).astype(np.float32), - np.random.randn(2).astype(np.float32), - mean=np.random.randn(2).astype(np.float32), - variance=np.random.randn(2).astype(np.float32), + x, [1.0, 1.0], [0.0, 0.0], + mean=[0.5, 0.5], + variance=[1.0, 1.0], data_format="NCHW", is_training=False) e = constant_op.constant( @@ -72,10 +70,11 @@ class VGGBlockNCHWTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines=["my_trt_op_0"], - expected_output_dims=(5, 6, 2, 2), - allclose_atol=1.e-03, - allclose_rtol=1.e-03) + expected_output_dims=(5, 6, 2, 2)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["my_trt_op_0"] if __name__ == "__main__": diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/contrib/tensorrt/test/vgg_block_test.py index 56bdf848ea..129795bf98 100644 --- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py +++ b/tensorflow/contrib/tensorrt/test/vgg_block_test.py @@ -42,11 +42,9 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): with g.as_default(): x = array_ops.placeholder(dtype=dtype, shape=input_dims, name=input_name) x, _, _ = nn_impl.fused_batch_norm( - x, - np.random.randn(2).astype(np.float32), - np.random.randn(2).astype(np.float32), - mean=np.random.randn(2).astype(np.float32), - variance=np.random.randn(2).astype(np.float32), + x, [1.0, 1.0], [0.0, 0.0], + mean=[0.5, 0.5], + variance=[1.0, 1.0], is_training=False) e = constant_op.constant( np.random.randn(1, 1, 2, 6), name="weights", dtype=dtype) @@ -63,10 +61,11 @@ class VGGBlockTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines=["my_trt_op_0"], - expected_output_dims=(5, 2, 2, 6), - allclose_atol=1.e-03, - allclose_rtol=1.e-03) + expected_output_dims=(5, 2, 2, 6)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["my_trt_op_0"] if __name__ == "__main__": |
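
The net effect of this change is that per-test expectations move out of TfTrtIntegrationTestParams and into overridable methods on TfTrtIntegrationTestBase. For illustration only, a new test written against the refactored base class might look like the sketch below; the class name, graph, and node names are hypothetical and not part of this change.

from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class ExampleSketchTest(trt_test.TfTrtIntegrationTestBase):
  """Hypothetical test sketch showing the refactored API."""

  def GetParams(self):
    """Build a trivial graph with one convertible two-op segment."""
    input_name = "input"
    input_dims = [8, 8, 8, 8]
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(
          dtype=dtypes.float32, shape=input_dims, name=input_name)
      q = math_ops.add(x, x, name="add")
      q = math_ops.add(q, q, name="add1")
      array_ops.identity(q, name=self.output_name)
    # Expectations and tolerances are no longer passed here.
    return trt_test.TfTrtIntegrationTestParams(
        gdef=g.as_graph_def(),
        input_names=[input_name],
        input_dims=[input_dims],
        expected_output_dims=tuple(input_dims))

  def GetConversionParams(self, run_params):
    """Override only the conversion fields that differ from the defaults."""
    return super(ExampleSketchTest,
                 self).GetConversionParams(run_params)._replace(
                     max_batch_size=8)

  def ExpectedEnginesToBuild(self, run_params):
    """Return the expected engines to build."""
    # Dict form also asks the base class to verify node-to-engine mapping.
    return {"my_trt_op_0": ["add", "add1"]}


if __name__ == "__main__":
  test.main()

Everything else now has a default in the base class and is overridden only where a test needs it: ExpectedEnginesToRun falls back to ExpectedEnginesToBuild, ShouldRunTest defaults to True, and ExpectedAbsoluteTolerance/ExpectedRelativeTolerance replace the old allclose_atol/allclose_rtol params.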
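The new manual_test.py takes its parameters from the TRT_MANUAL_TEST_PARAMS environment variable, a Python dict literal parsed with ast.literal_eval; when the variable is unset, ShouldRunTest sees an empty string and the test is skipped. A hypothetical setup might look like the following sketch; the graph path, tensor names, and shapes are placeholders, not values from this change.

import os

# All values below are placeholders for a user-supplied frozen graph.
# str() of a dict of literals round-trips through ast.literal_eval.
os.environ["TRT_MANUAL_TEST_PARAMS"] = str({
    "graph_path": "/tmp/my_frozen_graph.pb",  # serialized frozen GraphDef
    "input_names": ["input"],
    "input_dims": [[8, 28, 28, 3]],
    "output_name": "output",
    "expected_output_dims": (8, 28, 28, 3),
    "expected_engines_to_build": ["my_trt_op_0"],
    "max_batch_size": 8,
})

In practice the variable would be exported in the shell before launching the test, since the test process reads it when looking up parameters; the optional keys ('expected_engines_to_run', 'atol', 'rtol') fall back to the defaults defined in ManualTest.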