aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/dynamic_partition_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/dynamic_partition_op_test.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
index 5e8937ad2c..9557e30993 100644
--- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
+++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py
@@ -288,7 +288,7 @@ class DynamicPartitionTest(test.TestCase):
self.assertAllEqual([], partition_vals[i])
def testErrorIndexOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
data = constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11],
[12, 13, 14]])
indices = constant_op.constant([0, 2, 99, 2, 2])
@@ -298,7 +298,7 @@ class DynamicPartitionTest(test.TestCase):
sess.run(partitions)
def testScalarIndexOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
bad = 17
data = np.zeros(5)
partitions = data_flow_ops.dynamic_partition(data, bad, num_partitions=7)
@@ -306,7 +306,7 @@ class DynamicPartitionTest(test.TestCase):
sess.run(partitions)
def testHigherRankIndexOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
shape = (2, 3)
indices = array_ops.placeholder(shape=shape, dtype=np.int32)
data = np.zeros(shape + (5,))
@@ -334,7 +334,7 @@ class DynamicPartitionTest(test.TestCase):
inds += [13]*194 + [14]*194 + [15]*192
self.assertEqual(len(inds), x.shape[0])
partitioned = data_flow_ops.dynamic_partition(x, inds, 16)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
res = sess.run(partitioned)
self.assertEqual(res[-1].shape[0], 192)