aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/test/memory_alignment_test.py')
-rw-r--r--tensorflow/contrib/tensorrt/test/memory_alignment_test.py17
1 files changed, 13 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
index 66eb6be757..fd2c165f35 100644
--- a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
+++ b/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
@@ -62,10 +62,19 @@ class MemoryAlignmentTest(trt_test.TfTrtIntegrationTestBase):
gdef=g.as_graph_def(),
input_names=[input_name],
input_dims=[input_dims],
- expected_engines=["my_trt_op_0"],
- expected_output_dims=(2, 15, 15, 10),
- allclose_atol=1.e-02,
- allclose_rtol=1.e-02)
+ expected_output_dims=(2, 15, 15, 10))
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["my_trt_op_0"]
+
+ def ExpectedAbsoluteTolerance(self, run_params):
+ """The absolute tolerance to compare floating point results."""
+ return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-02
+
+ def ExpectedRelativeTolerance(self, run_params):
+ """The relative tolerance to compare floating point results."""
+ return 1.e-06 if run_params.precision_mode == "FP32" else 1.e-02
if __name__ == "__main__":