diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/reduce_join_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/reduce_join_op_test.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/python/kernel_tests/reduce_join_op_test.py b/tensorflow/python/kernel_tests/reduce_join_op_test.py index 663561ced7..3bb4986313 100644 --- a/tensorflow/python/kernel_tests/reduce_join_op_test.py +++ b/tensorflow/python/kernel_tests/reduce_join_op_test.py @@ -113,7 +113,7 @@ class ReduceJoinTest(UnicodeTestCase): keep_dims: Whether or not to retain reduced dimensions. separator: The separator to use for joining. """ - with self.test_session(): + with self.cached_session(): output = string_ops.reduce_join( inputs=input_array, axis=axis, @@ -136,7 +136,7 @@ class ReduceJoinTest(UnicodeTestCase): axis: The indices to reduce. separator: The separator to use when joining. """ - with self.test_session(): + with self.cached_session(): output = string_ops.reduce_join( inputs=input_array, axis=axis, keep_dims=False, separator=separator) output_keep_dims = string_ops.reduce_join( @@ -234,7 +234,7 @@ class ReduceJoinTest(UnicodeTestCase): input_array = [["a"], ["b"]] truth = ["ab"] truth_shape = None - with self.test_session(): + with self.cached_session(): placeholder = array_ops.placeholder(dtypes.string, name="placeholder") reduced = string_ops.reduce_join(placeholder, axis=0) output_array = reduced.eval(feed_dict={placeholder.name: input_array}) @@ -247,7 +247,7 @@ class ReduceJoinTest(UnicodeTestCase): truth_dim_zero = ["thisplease", "isdo", "anot", "testpanic"] truth_dim_one = ["thisisatest", "pleasedonotpanic"] truth_shape = None - with self.test_session(): + with self.cached_session(): placeholder = array_ops.placeholder(dtypes.int32, name="placeholder") reduced = string_ops.reduce_join(input_array, axis=placeholder) output_array_dim_zero = reduced.eval(feed_dict={placeholder.name: [0]}) @@ -298,7 +298,7 @@ class ReduceJoinTest(UnicodeTestCase): self._testMultipleReduceJoin(input_array, axis=permutation) def testInvalidReductionIndices(self): - with self.test_session(): + with self.cached_session(): with self.assertRaisesRegexp(ValueError, "Invalid reduction dim"): string_ops.reduce_join(inputs="", axis=0) with self.assertRaisesRegexp(ValueError, @@ -313,7 +313,7 @@ class ReduceJoinTest(UnicodeTestCase): string_ops.reduce_join(inputs=[[""]], axis=[0, 2]) def testZeroDims(self): - with self.test_session(): + with self.cached_session(): inputs = np.zeros([0, 1], dtype=str) # Reduction that drops the dim of size 0. @@ -326,7 +326,7 @@ class ReduceJoinTest(UnicodeTestCase): self.assertAllEqual([0], output_shape) def testInvalidArgsUnknownShape(self): - with self.test_session(): + with self.cached_session(): placeholder = array_ops.placeholder(dtypes.string, name="placeholder") index_too_high = string_ops.reduce_join(placeholder, axis=1) duplicate_index = string_ops.reduce_join(placeholder, axis=[-1, 1]) @@ -336,7 +336,7 @@ class ReduceJoinTest(UnicodeTestCase): duplicate_index.eval(feed_dict={placeholder.name: [[""]]}) def testInvalidArgsUnknownIndices(self): - with self.test_session(): + with self.cached_session(): placeholder = array_ops.placeholder(dtypes.int32, name="placeholder") reduced = string_ops.reduce_join(["test", "test2"], axis=placeholder) |