aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py')
-rw-r--r--tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py13
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py
index 60c429eefe..2598af4b27 100644
--- a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py
+++ b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py
@@ -31,6 +31,11 @@ class ZeroOut2Test(tf.test.TestCase):
result = zero_out_op_2.zero_out([5, 4, 3, 2, 1])
self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0])
+ def test_2d(self):
+ with self.test_session():
+ result = zero_out_op_2.zero_out([[6, 5, 4], [3, 2, 1]])
+ self.assertAllEqual(result.eval(), [[6, 0, 0], [0, 0, 0]])
+
def test_grad(self):
with self.test_session():
shape = (5,)
@@ -39,6 +44,14 @@ class ZeroOut2Test(tf.test.TestCase):
err = tf.test.compute_gradient_error(x, shape, y, shape)
self.assertLess(err, 1e-4)
+ def test_grad_2d(self):
+ with self.test_session():
+ shape = (2, 3)
+ x = tf.constant([[6, 5, 4], [3, 2, 1]], dtype=tf.float32)
+ y = zero_out_op_2.zero_out(x)
+ err = tf.test.compute_gradient_error(x, shape, y, shape)
+ self.assertLess(err, 1e-4)
+
if __name__ == '__main__':
tf.test.main()