aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/where_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/where_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/where_op_test.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/where_op_test.py b/tensorflow/python/kernel_tests/where_op_test.py
index 29fb002ef4..04ac589432 100644
--- a/tensorflow/python/kernel_tests/where_op_test.py
+++ b/tensorflow/python/kernel_tests/where_op_test.py
@@ -30,6 +30,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@@ -160,7 +161,7 @@ class WhereBenchmark(test.Benchmark):
x = random_ops.random_uniform((m, n), dtype=dtypes.float32) <= p
v = resource_variable_ops.ResourceVariable(x)
op = array_ops.where(v)
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
v.initializer.run()
r = self.run_op_benchmark(sess, op, min_iters=100, name=name)
gb_processed_input = m * n / 1.0e9
@@ -186,7 +187,7 @@ class WhereBenchmark(test.Benchmark):
y = resource_variable_ops.ResourceVariable(y_gen)
c = resource_variable_ops.ResourceVariable(c_gen)
op = array_ops.where(c, x, y)
- with session.Session() as sess:
+ with session.Session(config=benchmark.benchmark_config()) as sess:
x.initializer.run()
y.initializer.run()
c.initializer.run()