diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/matrix_exponential_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/matrix_exponential_op_test.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py index 0386e91276..9630c052b8 100644 --- a/tensorflow/python/kernel_tests/matrix_exponential_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_exponential_op_test.py @@ -30,6 +30,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables from tensorflow.python.ops.linalg import linalg_impl +from tensorflow.python.platform import benchmark from tensorflow.python.platform import test @@ -181,7 +182,7 @@ class MatrixExponentialBenchmark(test.Benchmark): def benchmarkMatrixExponentialOp(self): for shape in self.shapes: with ops.Graph().as_default(), \ - session.Session() as sess, \ + session.Session(config=benchmark.benchmark_config()) as sess, \ ops.device("/cpu:0"): matrix = self._GenerateMatrix(shape) expm = linalg_impl.matrix_exponential(matrix) @@ -195,7 +196,7 @@ class MatrixExponentialBenchmark(test.Benchmark): if test.is_gpu_available(True): with ops.Graph().as_default(), \ - session.Session() as sess, \ + session.Session(config=benchmark.benchmark_config()) as sess, \ ops.device("/gpu:0"): matrix = self._GenerateMatrix(shape) expm = linalg_impl.matrix_exponential(matrix) |