aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/concat_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/concat_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/concat_op_test.py28
1 files changed, 14 insertions, 14 deletions
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index c22934ce47..0e59ce6972 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -383,7 +383,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
output = array_ops.concat(xs, 0)
err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
@@ -397,7 +397,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
output = array_ops.concat(xs, 1)
err = gradient_checker.compute_gradient_error(xs, x_shapes, output,
@@ -411,7 +411,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
x_concat = array_ops.concat(xs, 0)
output = array_ops.gather(x_concat, [1, 2, 0, 5])
@@ -426,7 +426,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
x_concat = array_ops.concat(xs, 1)
output = array_ops.gather(x_concat, [1, 2, 0, 5])
@@ -441,7 +441,7 @@ class ConcatOpTest(test.TestCase):
np.random.random_sample(x_shape).astype(np.float64)
for x_shape in x_shapes
]
- with self.test_session():
+ with self.cached_session():
xs = [constant_op.constant(x_val) for x_val in x_vals]
x_concat = array_ops.concat(xs, 2)
output = array_ops.gather(x_concat, [1, 2, 0, 5])
@@ -452,7 +452,7 @@ class ConcatOpTest(test.TestCase):
def testIndexedSlicesConcatDim1Grad_UnknownInputDim(self):
x_shapes = [[20, 7, 3], [20, 3, 3], [20, 1, 3]]
output_shape = [4, 11, 3]
- with self.test_session():
+ with self.cached_session():
x_1 = array_ops.placeholder(dtypes.float64)
x_2 = array_ops.placeholder(dtypes.float64)
x_3 = array_ops.placeholder(dtypes.float64)
@@ -473,13 +473,13 @@ class ConcatOpTest(test.TestCase):
def testConcatTuple(self):
c1 = np.random.rand(4, 4)
c2 = np.random.rand(4, 4)
- with self.test_session():
+ with self.cached_session():
concat_list_t = array_ops.concat([c1, c2], 0)
concat_tuple_t = array_ops.concat((c1, c2), 0)
self.assertAllEqual(concat_list_t.eval(), concat_tuple_t.eval())
def testConcatNoScalars(self):
- with self.test_session():
+ with self.cached_session():
scalar = constant_op.constant(7)
dim = array_ops.placeholder(dtypes.int32)
with self.assertRaisesRegexp(
@@ -554,7 +554,7 @@ class ConcatOpTest(test.TestCase):
def _testGradientsForAxis(
self, inp_tensors, axis, output_shape, feed_dict=None):
- with self.test_session():
+ with self.cached_session():
c = array_ops.concat(inp_tensors, axis)
grad_inp = np.random.rand(*output_shape).astype("f")
grad_tensor = constant_op.constant(
@@ -566,7 +566,7 @@ class ConcatOpTest(test.TestCase):
def _testIndexedSlicesGradientsForAxis(
self, inp_tensors, axis, output_shape, gather_indexes, feed_dict=None):
- with self.test_session():
+ with self.cached_session():
c = array_ops.gather(
array_ops.concat(inp_tensors, axis), gather_indexes)
grad_inp = np.random.rand(*output_shape).astype("f")
@@ -631,7 +631,7 @@ class ConcatOffsetTest(test.TestCase):
self.assertAllEqual(ans, [[0, 0, 0], [0, 3, 0], [0, 10, 0]])
def testNotVector(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([[2, 3, 5]], dtypes.int32)
s1 = constant_op.constant([[2, 7, 5]], dtypes.int32)
@@ -641,7 +641,7 @@ class ConcatOffsetTest(test.TestCase):
sess.run(off)
def testConcatDimOutOfRange(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cdim = constant_op.constant(4, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5], dtypes.int32)
@@ -651,7 +651,7 @@ class ConcatOffsetTest(test.TestCase):
sess.run(off)
def testDimMismatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 5, 10], dtypes.int32)
@@ -661,7 +661,7 @@ class ConcatOffsetTest(test.TestCase):
sess.run(off)
def testSizeMismatch(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
s1 = constant_op.constant([2, 7, 10], dtypes.int32)