diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/concat_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/concat_op_test.py | 28 |
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py index c22934ce47..0e59ce6972 100644 --- a/tensorflow/python/kernel_tests/concat_op_test.py +++ b/tensorflow/python/kernel_tests/concat_op_test.py @@ -383,7 +383,7 @@ class ConcatOpTest(test.TestCase): np.random.random_sample(x_shape).astype(np.float64) for x_shape in x_shapes ] - with self.test_session(): + with self.cached_session(): xs = [constant_op.constant(x_val) for x_val in x_vals] output = array_ops.concat(xs, 0) err = gradient_checker.compute_gradient_error(xs, x_shapes, output, @@ -397,7 +397,7 @@ class ConcatOpTest(test.TestCase): np.random.random_sample(x_shape).astype(np.float64) for x_shape in x_shapes ] - with self.test_session(): + with self.cached_session(): xs = [constant_op.constant(x_val) for x_val in x_vals] output = array_ops.concat(xs, 1) err = gradient_checker.compute_gradient_error(xs, x_shapes, output, @@ -411,7 +411,7 @@ class ConcatOpTest(test.TestCase): np.random.random_sample(x_shape).astype(np.float64) for x_shape in x_shapes ] - with self.test_session(): + with self.cached_session(): xs = [constant_op.constant(x_val) for x_val in x_vals] x_concat = array_ops.concat(xs, 0) output = array_ops.gather(x_concat, [1, 2, 0, 5]) @@ -426,7 +426,7 @@ class ConcatOpTest(test.TestCase): np.random.random_sample(x_shape).astype(np.float64) for x_shape in x_shapes ] - with self.test_session(): + with self.cached_session(): xs = [constant_op.constant(x_val) for x_val in x_vals] x_concat = array_ops.concat(xs, 1) output = array_ops.gather(x_concat, [1, 2, 0, 5]) @@ -441,7 +441,7 @@ class ConcatOpTest(test.TestCase): np.random.random_sample(x_shape).astype(np.float64) for x_shape in x_shapes ] - with self.test_session(): + with self.cached_session(): xs = [constant_op.constant(x_val) for x_val in x_vals] x_concat = array_ops.concat(xs, 2) output = array_ops.gather(x_concat, [1, 2, 0, 5]) @@ -452,7 +452,7 @@ class ConcatOpTest(test.TestCase): def testIndexedSlicesConcatDim1Grad_UnknownInputDim(self): x_shapes = [[20, 7, 3], [20, 3, 3], [20, 1, 3]] output_shape = [4, 11, 3] - with self.test_session(): + with self.cached_session(): x_1 = array_ops.placeholder(dtypes.float64) x_2 = array_ops.placeholder(dtypes.float64) x_3 = array_ops.placeholder(dtypes.float64) @@ -473,13 +473,13 @@ class ConcatOpTest(test.TestCase): def testConcatTuple(self): c1 = np.random.rand(4, 4) c2 = np.random.rand(4, 4) - with self.test_session(): + with self.cached_session(): concat_list_t = array_ops.concat([c1, c2], 0) concat_tuple_t = array_ops.concat((c1, c2), 0) self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval()) def testConcatNoScalars(self): - with self.test_session(): + with self.cached_session(): scalar = constant_op.constant(7) dim = array_ops.placeholder(dtypes.int32) with self.assertRaisesRegexp( @@ -554,7 +554,7 @@ class ConcatOpTest(test.TestCase): def _testGradientsForAxis( self, inp_tensors, axis, output_shape, feed_dict=None): - with self.test_session(): + with self.cached_session(): c = array_ops.concat(inp_tensors, axis) grad_inp = np.random.rand(*output_shape).astype("f") grad_tensor = constant_op.constant( @@ -566,7 +566,7 @@ class ConcatOpTest(test.TestCase): def _testIndexedSlicesGradientsForAxis( self, inp_tensors, axis, output_shape, gather_indexes, feed_dict=None): - with self.test_session(): + with self.cached_session(): c = array_ops.gather( array_ops.concat(inp_tensors, axis), gather_indexes) grad_inp = np.random.rand(*output_shape).astype("f") @@ -631,7 +631,7 @@ class ConcatOffsetTest(test.TestCase): self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]]) def testNotVector(self): - with self.test_session() as sess: + with self.cached_session() as sess: cdim = constant_op.constant(1, dtypes.int32) s0 = constant_op.constant([[2, 3, 5]], dtypes.int32) s1 = constant_op.constant([[2, 7, 5]], dtypes.int32) @@ -641,7 +641,7 @@ class ConcatOffsetTest(test.TestCase): sess.run(off) def testConcatDimOutOfRange(self): - with self.test_session() as sess: + with self.cached_session() as sess: cdim = constant_op.constant(4, dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5], dtypes.int32) @@ -651,7 +651,7 @@ class ConcatOffsetTest(test.TestCase): sess.run(off) def testDimMismatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: cdim = constant_op.constant(1, dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 5, 10], dtypes.int32) @@ -661,7 +661,7 @@ class ConcatOffsetTest(test.TestCase): sess.run(off) def testSizeMismatch(self): - with self.test_session() as sess: + with self.cached_session() as sess: cdim = constant_op.constant(1, dtypes.int32) s0 = constant_op.constant([2, 3, 5], dtypes.int32) s1 = constant_op.constant([2, 7, 10], dtypes.int32) |