aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-10 15:43:51 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-10 15:49:05 -0700
commite32029541ae270a021b266fcc3929b2528f8dff1 (patch)
tree40abaa2485e86b41b10d317af3969754e5cdb789 /tensorflow/python/framework
parent6951e0646d7dc8931b6cbe4388dcc3921249d462 (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 212348850
Diffstat (limited to 'tensorflow/python/framework')
-rw-r--r--tensorflow/python/framework/file_system_test.py2
-rw-r--r--tensorflow/python/framework/function_test.py10
-rw-r--r--tensorflow/python/framework/importer_test.py18
-rw-r--r--tensorflow/python/framework/meta_graph_test.py9
-rw-r--r--tensorflow/python/framework/ops_test.py50
-rw-r--r--tensorflow/python/framework/sparse_tensor_test.py6
-rw-r--r--tensorflow/python/framework/subscribe_test.py14
-rw-r--r--tensorflow/python/framework/tensor_util_test.py2
8 files changed, 56 insertions, 55 deletions
diff --git a/tensorflow/python/framework/file_system_test.py b/tensorflow/python/framework/file_system_test.py
index 5eb59141a2..6901715e5d 100644
--- a/tensorflow/python/framework/file_system_test.py
+++ b/tensorflow/python/framework/file_system_test.py
@@ -37,7 +37,7 @@ class FileSystemTest(test.TestCase):
load_library.load_file_system_library(file_system_library)
def testBasic(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
reader = io_ops.WholeFileReader("test_reader")
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
queue.enqueue_many([["test://foo"]]).run()
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index ee723bacaf..903768a039 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -419,7 +419,7 @@ class FunctionTest(test.TestCase):
with ops.control_dependencies([z]):
return x * 2
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
z = Foo(constant_op.constant(3.0))
self.assertAllEqual(z.eval(), 6.0)
@@ -434,7 +434,7 @@ class FunctionTest(test.TestCase):
# Foo contains a stateful op (Assert).
self.assertEqual([("Assert", "Assert")], Foo.stateful_ops)
g = ops.Graph()
- with g.as_default(), self.test_session():
+ with g.as_default(), self.cached_session():
self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0)
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion failed.*-3"):
@@ -448,7 +448,7 @@ class FunctionTest(test.TestCase):
[control_flow_ops.Assert(math_ops.less_equal(x, 10.0), [x])]):
return array_ops.identity(x)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(1.0, MyFn(1.0).eval())
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion"):
@@ -667,7 +667,7 @@ class FunctionTest(test.TestCase):
with ops.Graph().as_default():
z = CubeXPlusY(3.0, -2.0)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(z.eval(), 25.0)
def testNestedDefinedFunction(self):
@@ -683,7 +683,7 @@ class FunctionTest(test.TestCase):
with ops.Graph().as_default():
z = CubeXPlusY(3.0, -2.0)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(z.eval(), 25.0)
def testUnusedFunction(self):
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 18e7d8aa14..2b4d8e7299 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -396,7 +396,7 @@ class ImportGraphDefTest(test.TestCase):
# Run the imported graph.
# TODO(b/76173421): make this work (currently DCHECKS)
- # with self.test_session() as sess:
+ # with self.cached_session() as sess:
# sess.run(imported_init)
# self.assertEqual(sess.run(imported_var), 1.0)
# self.assertEqual(sess.run(imported_assign), 2.0)
@@ -417,7 +417,7 @@ class ImportGraphDefTest(test.TestCase):
imported_r, = importer.import_graph_def(graph_def,
return_elements=[r.name])
self.assertEqual(imported_r.name, "import/" + r.name)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(imported_r), 10)
def testImportWhileLoopInCond(self):
@@ -436,7 +436,7 @@ class ImportGraphDefTest(test.TestCase):
pred = array_ops.placeholder(dtypes.bool)
out = control_flow_ops.cond(pred, ImportFn,
lambda: constant_op.constant(1))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(out, {pred: True}), 10)
self.assertEqual(sess.run(out, {pred: False}), 1)
@@ -457,7 +457,7 @@ class ImportGraphDefTest(test.TestCase):
out = control_flow_ops.while_loop(
lambda i: i < 2, ImportFn, [0],
shape_invariants=[tensor_shape.TensorShape(None)])
- with self.test_session() as sess:
+ with self.cached_session() as sess:
self.assertEqual(sess.run(out), 10)
def testTypeMismatchInGraphDef(self):
@@ -929,7 +929,7 @@ class ImportGraphDefTest(test.TestCase):
input_map={"a:0": constant_op.constant(5.0)},
name="",
return_elements=["id:0"])
- with self.test_session():
+ with self.cached_session():
self.assertEqual(5.0, t.eval())
def testInvalidInputForReturnOperations(self):
@@ -958,7 +958,7 @@ class ImportGraphDefTest(test.TestCase):
array_ops.stack([c, c], name="pack")
gdef = g.as_graph_def()
- with self.test_session():
+ with self.cached_session():
pack, = importer.import_graph_def(gdef, return_elements=["pack"])
self.assertAllEqual(pack.outputs[0].eval(), [5.0, 5.0])
@@ -1063,7 +1063,7 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual([10], biases_grad.get_shape())
def testLargeGraph(self):
- with self.test_session():
+ with self.cached_session():
# The default message byte limit is 64M. Ours is 2G with a warning at 512.
# Adding a 130M entries float32 tensor should exceed the warning, but not
# the hard limit.
@@ -1254,7 +1254,7 @@ class ImportGraphDefTest(test.TestCase):
z = TestFunc()
- with self.test_session():
+ with self.cached_session():
z_val = z.eval()
self.assertEqual(z_val, -2.0)
@@ -1284,7 +1284,7 @@ class ImportGraphDefTest(test.TestCase):
z2 = importer.import_graph_def(gdef, return_elements=["z:0"],
input_map=input_map)[0]
- with self.test_session() as sess:
+ with self.cached_session() as sess:
z1_val, z2_val = sess.run((z1, z2))
self.assertAllEqual(z1_val, z2_val)
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py
index 6e5f7aafac..fc98b91a01 100644
--- a/tensorflow/python/framework/meta_graph_test.py
+++ b/tensorflow/python/framework/meta_graph_test.py
@@ -117,7 +117,7 @@ class SimpleMetaGraphTest(test.TestCase):
self.assertEqual(new_output_value, output_value)
def testStrippedOpListNestedFunctions(self):
- with self.test_session():
+ with self.cached_session():
# Square two levels deep
@function.Defun(dtypes.int32)
def f0(x):
@@ -169,7 +169,7 @@ class SimpleMetaGraphTest(test.TestCase):
# and "Tout" maps to complex64. Since these attr values map to their
# defaults, they must be stripped unless stripping of default attrs is
# disabled.
- with self.test_session():
+ with self.cached_session():
real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
@@ -212,7 +212,8 @@ class SimpleMetaGraphTest(test.TestCase):
def testDefaultAttrStrippingNestedFunctions(self):
"""Verifies that default attributes are stripped from function node defs."""
- with self.test_session():
+ with self.cached_session():
+
@function.Defun(dtypes.float32, dtypes.float32)
def f0(i, j):
return math_ops.complex(i, j, name="double_nested_complex")
@@ -251,7 +252,7 @@ class SimpleMetaGraphTest(test.TestCase):
meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
meta_info_def.stripped_op_list.op.add()
- with self.test_session():
+ with self.cached_session():
meta_graph_def = meta_graph.create_meta_graph_def(
meta_info_def=meta_info_def, graph_def=graph_def,
strip_default_attrs=True)
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index ced0581402..d59adf3d48 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -58,12 +58,12 @@ ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn)
class ResourceTest(test_util.TensorFlowTestCase):
def testBuildGraph(self):
- with self.test_session():
+ with self.cached_session():
pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
test_ops.resource_create_op(pt).run()
def testInitialize(self):
- with self.test_session():
+ with self.cached_session():
handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
resources.register_resource(
handle=handle,
@@ -100,35 +100,35 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
pass
def testAddShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.zeros([2, 3])
b = array_ops.ones([1, 3])
c = a + b
self.assertEqual([2, 3], c.shape)
def testUnknownDim(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
c = a + b
self.assertEqual([2, None, 3], c.shape.as_list())
def testUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
b = array_ops.ones([1, 3])
c = a + b
self.assertEqual(tensor_shape.unknown_shape(), c.shape)
def testScalarShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
b = array_ops.ones([])
c = a + b
self.assertEqual(tensor_shape.scalar(), c.shape)
def testShapeFunctionError(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.ones([1, 2, 3])
b = array_ops.ones([4, 5, 6])
with self.assertRaisesRegexp(
@@ -141,7 +141,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
class IndexedSlicesTest(test_util.TensorFlowTestCase):
def testToTensor(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
dense_shape = constant_op.constant([3, 2])
@@ -150,7 +150,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]])
def testNegation(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
x = -ops.IndexedSlices(values, indices)
@@ -158,7 +158,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
self.assertAllEqual(x.indices.eval(), [0, 2])
def testScalarMul(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices))
@@ -307,14 +307,14 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())
def testConvertToTensorNestedArray(self):
- with self.test_session():
+ with self.cached_session():
values = [[2], [3], [5], [7]]
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(values, tensor.eval())
def testShapeTuple(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(1)
self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access
@@ -328,14 +328,14 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertTrue(isinstance(converted, ops.EagerTensor))
def testConvertToTensorNestedTuple(self):
- with self.test_session():
+ with self.cached_session():
values = ((2,), (3,), (5,), (7,))
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(values, ops.convert_to_tensor(values).eval())
def testConvertToTensorNestedTensors(self):
- with self.test_session():
+ with self.cached_session():
values = ((2,), (3,), (5,), (7,))
tensor = ops.convert_to_tensor(
[constant_op.constant(row) for row in values])
@@ -347,25 +347,25 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertAllEqual(values, tensor.eval())
def testConvertToTensorNestedMix(self):
- with self.test_session():
+ with self.cached_session():
values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(((2,), (3,), (5,), (7,)), tensor.eval())
def testConvertToTensorPreferred(self):
- with self.test_session():
+ with self.cached_session():
values = [2, 3, 5, 7]
tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
self.assertEqual(dtypes.float32, tensor.dtype)
- with self.test_session():
+ with self.cached_session():
# Convert empty tensor to anything.
values = []
tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
self.assertEqual(dtypes.int64, tensor.dtype)
- with self.test_session():
+ with self.cached_session():
# The preferred dtype is a type error and will convert to
# float32 instead.
values = [1.23]
@@ -941,7 +941,7 @@ class NameStackTest(test_util.TensorFlowTestCase):
self.assertEqual("bar_2", g.unique_name("bar"))
def testNameAndVariableScope(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with sess.graph.name_scope("l0"):
with variable_scope.variable_scope("l1"):
with sess.graph.name_scope("l1") as scope:
@@ -2164,7 +2164,7 @@ class InitScopeTest(test_util.TensorFlowTestCase):
g = ops.Graph()
with g.as_default():
- with self.test_session():
+ with self.cached_session():
# First ensure that graphs that are not building functions are
# not escaped.
function_with_variables("foo")
@@ -2416,11 +2416,11 @@ class AttrScopeTest(test_util.TensorFlowTestCase):
return (a, b)
def testNoLabel(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual((None, None), self._get_test_attrs())
def testLabelMap(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a1 = self._get_test_attrs()
with sess.graph._attr_scope({
"_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo"))
@@ -2454,12 +2454,12 @@ ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
class KernelLabelTest(test_util.TensorFlowTestCase):
def testNoLabel(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(b"My label is: default",
test_ops.kernel_label().eval())
def testLabelMap(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_1 = test_ops.kernel_label()
# pylint: disable=protected-access
with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}):
@@ -2900,7 +2900,7 @@ class NameScopeTest(test_util.TensorFlowTestCase):
class TracebackTest(test_util.TensorFlowTestCase):
def testTracebackWithStartLines(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant(2.0)
sess.run(
a,
diff --git a/tensorflow/python/framework/sparse_tensor_test.py b/tensorflow/python/framework/sparse_tensor_test.py
index 2bcfbc17df..22423c4f58 100644
--- a/tensorflow/python/framework/sparse_tensor_test.py
+++ b/tensorflow/python/framework/sparse_tensor_test.py
@@ -45,7 +45,7 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
self.assertEqual(sp.dense_shape.dtype, dtypes.int64)
self.assertEqual(sp.get_shape(), (4, 5))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
value = sp.eval()
self.assertAllEqual(indices, value.indices)
self.assertAllEqual(values, value.values)
@@ -81,14 +81,14 @@ class SparseTensorTest(test_util.TensorFlowTestCase):
class ConvertToTensorOrSparseTensorTest(test_util.TensorFlowTestCase):
def test_convert_dense(self):
- with self.test_session():
+ with self.cached_session():
value = [42, 43]
from_value = sparse_tensor.convert_to_tensor_or_sparse_tensor(
value)
self.assertAllEqual(value, from_value.eval())
def test_convert_sparse(self):
- with self.test_session():
+ with self.cached_session():
indices = [[0, 1], [1, 0]]
values = [42, 43]
shape = [2, 2]
diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py
index d6de45fdc4..1d594e4078 100644
--- a/tensorflow/python/framework/subscribe_test.py
+++ b/tensorflow/python/framework/subscribe_test.py
@@ -65,7 +65,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertFalse(c0.op in d.op.control_inputs)
self.assertTrue(c.op in d.op.control_inputs)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c_out = sess.run([c])
n_out = sess.run([n])
d_out = sess.run([d])
@@ -144,7 +144,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
b = subscribe.subscribe(b,
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
c_out = sess.run([c])
d_out = sess.run([d])
@@ -204,7 +204,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertIs(c_sub, c_sub3)
# Expect the three side effect graphs to have been evaluated.
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([c_sub])
self.assertIn('graph1', shared)
self.assertIn('graph2', shared)
@@ -227,7 +227,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
v1, lambda t: script_ops.py_func(sub, [t], [t.dtype]))
self.assertTrue(subscribe._is_subscribed_identity(v1_sub))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# Initialize the variables first.
sess.run([v1.initializer])
sess.run([v2.initializer])
@@ -272,7 +272,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
self.assertIs(tensor_array_sub, tensor_array.handle)
self.assertFalse(subscribe._is_subscribed_identity(tensor_array.handle))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([reader])
self.assertEqual(0, len(shared))
@@ -303,7 +303,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
subscribe.subscribe(sparse_add.op.outputs,
lambda t: script_ops.py_func(sub, [t], [t.dtype]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run([neg])
# All three ops have been processed.
@@ -374,7 +374,7 @@ class SubscribeTest(test_util.TensorFlowTestCase):
# Verify that sub(x1) and sub(branch) are not.
self.assertIsNot(context(subscriptions[0]), context(subscriptions[1]))
- with self.test_session() as sess:
+ with self.cached_session() as sess:
sess.run(cond)
self.assertEqual(3, len(results))
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index 395cf43b3f..bdf759f220 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -768,7 +768,7 @@ class TensorUtilTest(test.TestCase):
def __array__(self, dtype=None):
return np.asarray(self.array, dtype)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
ma = MockArray(np.array([10, 20, 30]))
t = ops.convert_to_tensor(ma)
a = sess.run(t)