diff options
Diffstat (limited to 'tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py')
-rw-r--r-- | tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py index 60c429eefe..2598af4b27 100644 --- a/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py +++ b/tensorflow/g3doc/how_tos/adding_an_op/zero_out_2_test.py @@ -31,6 +31,11 @@ class ZeroOut2Test(tf.test.TestCase): result = zero_out_op_2.zero_out([5, 4, 3, 2, 1]) self.assertAllEqual(result.eval(), [5, 0, 0, 0, 0]) + def test_2d(self): + with self.test_session(): + result = zero_out_op_2.zero_out([[6, 5, 4], [3, 2, 1]]) + self.assertAllEqual(result.eval(), [[6, 0, 0], [0, 0, 0]]) + def test_grad(self): with self.test_session(): shape = (5,) @@ -39,6 +44,14 @@ class ZeroOut2Test(tf.test.TestCase): err = tf.test.compute_gradient_error(x, shape, y, shape) self.assertLess(err, 1e-4) + def test_grad_2d(self): + with self.test_session(): + shape = (2, 3) + x = tf.constant([[6, 5, 4], [3, 2, 1]], dtype=tf.float32) + y = zero_out_op_2.zero_out(x) + err = tf.test.compute_gradient_error(x, shape, y, shape) + self.assertLess(err, 1e-4) + if __name__ == '__main__': tf.test.main() |