diff options
author | Guangda Lai <31743510+aaroey@users.noreply.github.com> | 2018-09-18 21:40:01 -0700 |
---|---|---|
committer | Guangda Lai <31743510+aaroey@users.noreply.github.com> | 2018-09-18 21:40:01 -0700 |
commit | 65231a4c48ce3a1297d00e2a6310be05e79ed88c (patch) | |
tree | e6e104cf1a3cb6a2efce9ef754d113ee3ccd16e0 /tensorflow/contrib/tensorrt | |
parent | 9ee75bb6e29007b8b5ea4a6d981996d8a4d88373 (diff) |
Fix python3 tests
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r-- | tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py | 8 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py | 2 |
2 files changed, 1 insertions, 9 deletions
diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py index 62f4e525f7..d2f65344da 100644 --- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py +++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py @@ -144,14 +144,6 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase): # mode, which is a bug. Re-enable this when trt library is fixed. return not trt_test.IsQuantizationMode(run_params.precision_mode) - def ExpectedAbsoluteTolerance(self, run_params): - """The absolute tolerance to compare floating point results.""" - return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03 - - def ExpectedRelativeTolerance(self, run_params): - """The relative tolerance to compare floating point results.""" - return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03 - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py index 699f79adec..4f935a7665 100644 --- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py +++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py @@ -134,7 +134,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase): dims[0] for dims in self._GetParamsCached().input_dims if len(dims) ]), max_workspace_size_bytes=1 << 25, - precision_mode=self._ToBytes(run_params.precision_mode), + precision_mode=run_params.precision_mode, minimum_segment_size=2, is_dynamic_op=run_params.dynamic_engine, maximum_cached_engines=1, |