aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt
diff options
context:
space:
mode:
authorGravatar Guangda Lai <31743510+aaroey@users.noreply.github.com>2018-09-18 21:40:01 -0700
committerGravatar Guangda Lai <31743510+aaroey@users.noreply.github.com>2018-09-18 21:40:01 -0700
commit65231a4c48ce3a1297d00e2a6310be05e79ed88c (patch)
treee6e104cf1a3cb6a2efce9ef754d113ee3ccd16e0 /tensorflow/contrib/tensorrt
parent9ee75bb6e29007b8b5ea4a6d981996d8a4d88373 (diff)
Fix python3 tests
Diffstat (limited to 'tensorflow/contrib/tensorrt')
-rw-r--r--tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py8
-rw-r--r--tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py2
2 files changed, 1 insertions, 9 deletions
diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
index 62f4e525f7..d2f65344da 100644
--- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
@@ -144,14 +144,6 @@ class BiasaddMatMulTest(trt_test.TfTrtIntegrationTestBase):
# mode, which is a bug. Re-enable this when trt library is fixed.
return not trt_test.IsQuantizationMode(run_params.precision_mode)
- def ExpectedAbsoluteTolerance(self, run_params):
- """The absolute tolerance to compare floating point results."""
- return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
-
- def ExpectedRelativeTolerance(self, run_params):
- """The relative tolerance to compare floating point results."""
- return 1.e-05 if run_params.precision_mode == "FP32" else 1.e-03
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
index 699f79adec..4f935a7665 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
@@ -134,7 +134,7 @@ class TfTrtIntegrationTestBase(test_util.TensorFlowTestCase):
dims[0] for dims in self._GetParamsCached().input_dims if len(dims)
]),
max_workspace_size_bytes=1 << 25,
- precision_mode=self._ToBytes(run_params.precision_mode),
+ precision_mode=run_params.precision_mode,
minimum_segment_size=2,
is_dynamic_op=run_params.dynamic_engine,
maximum_cached_engines=1,