aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/adding_an_op/zero_out_2_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples/adding_an_op/zero_out_2_test.py')
-rw-r--r--tensorflow/examples/adding_an_op/zero_out_2_test.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/tensorflow/examples/adding_an_op/zero_out_2_test.py b/tensorflow/examples/adding_an_op/zero_out_2_test.py
index 217bbbcffa..4504597817 100644
--- a/tensorflow/examples/adding_an_op/zero_out_2_test.py
+++ b/tensorflow/examples/adding_an_op/zero_out_2_test.py
@@ -29,17 +29,17 @@ from tensorflow.examples.adding_an_op import zero_out_op_2
class ZeroOut2Test(tf.test.TestCase):
def test(self):
- with self.test_session():
+ with self.cached_session():
result = zero_out_op_2.zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
def test_2d(self):
- with self.test_session():
+ with self.cached_session():
result = zero_out_op_2.zero_out([[6, 5, 4], [3, 2, 1]])
self.assertAllEqual(result.eval(), [[6, 0, 0], [0, 0, 0]])
def test_grad(self):
- with self.test_session():
+ with self.cached_session():
shape = (5,)
x = tf.constant([5, 4, 3, 2, 1], dtype=tf.float32)
y = zero_out_op_2.zero_out(x)
@@ -47,7 +47,7 @@ class ZeroOut2Test(tf.test.TestCase):
self.assertLess(err, 1e-4)
def test_grad_2d(self):
- with self.test_session():
+ with self.cached_session():
shape = (2, 3)
x = tf.constant([[6, 5, 4], [3, 2, 1]], dtype=tf.float32)
y = zero_out_op_2.zero_out(x)