diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/where_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/where_op_test.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py index 29fb002ef4..04ac589432 100644 --- a/tensorflow/python/kernel_tests/where_op_test.py +++ b/tensorflow/python/kernel_tests/where_op_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import benchmark from tensorflow.python.platform import test @@ -160,7 +161,7 @@ class WhereBenchmark(test.Benchmark): x = random_ops.random_uniform((m, n), dtype=dtypes.float32) <= p v = resource_variable_ops.ResourceVariable(x) op = array_ops.where(v) - with session.Session() as sess: + with session.Session(config=benchmark.benchmark_config()) as sess: v.initializer.run() r = self.run_op_benchmark(sess, op, min_iters=100, name=name) gb_processed_input = m * n / 1.0e9 @@ -186,7 +187,7 @@ class WhereBenchmark(test.Benchmark): y = resource_variable_ops.ResourceVariable(y_gen) c = resource_variable_ops.ResourceVariable(c_gen) op = array_ops.where(c, x, y) - with session.Session() as sess: + with session.Session(config=benchmark.benchmark_config()) as sess: x.initializer.run() y.initializer.run() c.initializer.run() |