aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/ops_test.py')
-rw-r--r--tensorflow/python/framework/ops_test.py50
1 files changed, 25 insertions, 25 deletions
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index ced0581402..d59adf3d48 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -58,12 +58,12 @@ ops._set_call_cpp_shape_fn(common_shapes.call_cpp_shape_fn)
class ResourceTest(test_util.TensorFlowTestCase):
def testBuildGraph(self):
- with self.test_session():
+ with self.cached_session():
pt = test_ops.stub_resource_handle_op(container="a", shared_name="b")
test_ops.resource_create_op(pt).run()
def testInitialize(self):
- with self.test_session():
+ with self.cached_session():
handle = test_ops.stub_resource_handle_op(container="a", shared_name="b")
resources.register_resource(
handle=handle,
@@ -100,35 +100,35 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
pass
def testAddShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.zeros([2, 3])
b = array_ops.ones([1, 3])
c = a + b
self.assertEqual([2, 3], c.shape)
def testUnknownDim(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
b = array_ops.placeholder(dtype=dtypes.float32, shape=[2, None, 3])
c = a + b
self.assertEqual([2, None, 3], c.shape.as_list())
def testUnknownShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=None)
b = array_ops.ones([1, 3])
c = a + b
self.assertEqual(tensor_shape.unknown_shape(), c.shape)
def testScalarShape(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.placeholder(dtype=dtypes.float32, shape=[])
b = array_ops.ones([])
c = a + b
self.assertEqual(tensor_shape.scalar(), c.shape)
def testShapeFunctionError(self):
- with self.test_session():
+ with self.cached_session():
a = array_ops.ones([1, 2, 3])
b = array_ops.ones([4, 5, 6])
with self.assertRaisesRegexp(
@@ -141,7 +141,7 @@ class TensorAndShapeTest(test_util.TensorFlowTestCase):
class IndexedSlicesTest(test_util.TensorFlowTestCase):
def testToTensor(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
dense_shape = constant_op.constant([3, 2])
@@ -150,7 +150,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tensor.eval(), [[2, 3], [0, 0], [5, 7]])
def testNegation(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
x = -ops.IndexedSlices(values, indices)
@@ -158,7 +158,7 @@ class IndexedSlicesTest(test_util.TensorFlowTestCase):
self.assertAllEqual(x.indices.eval(), [0, 2])
def testScalarMul(self):
- with self.test_session():
+ with self.cached_session():
values = constant_op.constant([2, 3, 5, 7], shape=[2, 2])
indices = constant_op.constant([0, 2])
x = math_ops.scalar_mul(-2, ops.IndexedSlices(values, indices))
@@ -307,14 +307,14 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertEqual(tensor_shape.unknown_shape(), op.get_shape())
def testConvertToTensorNestedArray(self):
- with self.test_session():
+ with self.cached_session():
values = [[2], [3], [5], [7]]
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(values, tensor.eval())
def testShapeTuple(self):
- with self.test_session():
+ with self.cached_session():
c = constant_op.constant(1)
self.assertEqual(c._shape_tuple(), ()) # pylint: disable=protected-access
@@ -328,14 +328,14 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertTrue(isinstance(converted, ops.EagerTensor))
def testConvertToTensorNestedTuple(self):
- with self.test_session():
+ with self.cached_session():
values = ((2,), (3,), (5,), (7,))
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(values, ops.convert_to_tensor(values).eval())
def testConvertToTensorNestedTensors(self):
- with self.test_session():
+ with self.cached_session():
values = ((2,), (3,), (5,), (7,))
tensor = ops.convert_to_tensor(
[constant_op.constant(row) for row in values])
@@ -347,25 +347,25 @@ class OperationTest(test_util.TensorFlowTestCase):
self.assertAllEqual(values, tensor.eval())
def testConvertToTensorNestedMix(self):
- with self.test_session():
+ with self.cached_session():
values = ([2], (3,), [constant_op.constant(5)], constant_op.constant([7]))
tensor = ops.convert_to_tensor(values)
self.assertAllEqual((4, 1), tensor.get_shape().as_list())
self.assertAllEqual(((2,), (3,), (5,), (7,)), tensor.eval())
def testConvertToTensorPreferred(self):
- with self.test_session():
+ with self.cached_session():
values = [2, 3, 5, 7]
tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.float32)
self.assertEqual(dtypes.float32, tensor.dtype)
- with self.test_session():
+ with self.cached_session():
# Convert empty tensor to anything.
values = []
tensor = ops.convert_to_tensor(values, preferred_dtype=dtypes.int64)
self.assertEqual(dtypes.int64, tensor.dtype)
- with self.test_session():
+ with self.cached_session():
# The preferred dtype is a type error and will convert to
# float32 instead.
values = [1.23]
@@ -941,7 +941,7 @@ class NameStackTest(test_util.TensorFlowTestCase):
self.assertEqual("bar_2", g.unique_name("bar"))
def testNameAndVariableScope(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
with sess.graph.name_scope("l0"):
with variable_scope.variable_scope("l1"):
with sess.graph.name_scope("l1") as scope:
@@ -2164,7 +2164,7 @@ class InitScopeTest(test_util.TensorFlowTestCase):
g = ops.Graph()
with g.as_default():
- with self.test_session():
+ with self.cached_session():
# First ensure that graphs that are not building functions are
# not escaped.
function_with_variables("foo")
@@ -2416,11 +2416,11 @@ class AttrScopeTest(test_util.TensorFlowTestCase):
return (a, b)
def testNoLabel(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual((None, None), self._get_test_attrs())
def testLabelMap(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a1 = self._get_test_attrs()
with sess.graph._attr_scope({
"_A": attr_value_pb2.AttrValue(s=compat.as_bytes("foo"))
@@ -2454,12 +2454,12 @@ ops.RegisterShape("KernelLabel")(common_shapes.scalar_shape)
class KernelLabelTest(test_util.TensorFlowTestCase):
def testNoLabel(self):
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(b"My label is: default",
test_ops.kernel_label().eval())
def testLabelMap(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
default_1 = test_ops.kernel_label()
# pylint: disable=protected-access
with sess.graph._kernel_label_map({"KernelLabel": "overload_1"}):
@@ -2900,7 +2900,7 @@ class NameScopeTest(test_util.TensorFlowTestCase):
class TracebackTest(test_util.TensorFlowTestCase):
def testTracebackWithStartLines(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
a = constant_op.constant(2.0)
sess.run(
a,