aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/function_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/function_test.py')
-rw-r--r--tensorflow/python/framework/function_test.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index ee723bacaf..903768a039 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -419,7 +419,7 @@ class FunctionTest(test.TestCase):
with ops.control_dependencies([z]):
return x * 2
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
z = Foo(constant_op.constant(3.0))
self.assertAllEqual(z.eval(), 6.0)
@@ -434,7 +434,7 @@ class FunctionTest(test.TestCase):
# Foo contains a stateful op (Assert).
self.assertEqual([("Assert", "Assert")], Foo.stateful_ops)
g = ops.Graph()
- with g.as_default(), self.test_session():
+ with g.as_default(), self.cached_session():
self.assertAllEqual(Foo(constant_op.constant(3.0)).eval(), 6.0)
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion failed.*-3"):
@@ -448,7 +448,7 @@ class FunctionTest(test.TestCase):
[control_flow_ops.Assert(math_ops.less_equal(x, 10.0), [x])]):
return array_ops.identity(x)
- with self.test_session():
+ with self.cached_session():
self.assertEqual(1.0, MyFn(1.0).eval())
with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"assertion"):
@@ -667,7 +667,7 @@ class FunctionTest(test.TestCase):
with ops.Graph().as_default():
z = CubeXPlusY(3.0, -2.0)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(z.eval(), 25.0)
def testNestedDefinedFunction(self):
@@ -683,7 +683,7 @@ class FunctionTest(test.TestCase):
with ops.Graph().as_default():
z = CubeXPlusY(3.0, -2.0)
- with self.test_session():
+ with self.cached_session():
self.assertAllEqual(z.eval(), 25.0)
def testUnusedFunction(self):