diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/test/memory_alignment_test.py')
-rw-r--r-- | tensorflow/contrib/tensorrt/test/memory_alignment_test.py | 17 |
1 files changed, 13 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py index 66eb6be757..fd2c165f35 100644 --- a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py +++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py @@ -62,10 +62,19 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase): gdef=g.as_graph_def(), input_names=[input_name], input_dims=[input_dims], - expected_engines=["my_trt_op_0"], - expected_output_dims=(2, 15, 15, 10), - allclose_atol=1.e-02, - allclose_rtol=1.e-02) + expected_output_dims=(2, 15, 15, 10)) + + def ExpectedEnginesToBuild(self, run_params): + """Return the expected engines to build.""" + return ["my_trt_op_0"] + + def ExpectedAbsoluteTolerance(self, run_params): + """The absolute tolerance to compare floating point results.""" + return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-02 + + def ExpectedRelativeTolerance(self, run_params): + """The relative tolerance to compare floating point results.""" + return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-02 if __name__ == "__main__": |