diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/clip_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/clip_ops_test.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py index fb52d10475..400d38b936 100644 --- a/tensorflow/python/kernel_tests/clip_ops_test.py +++ b/tensorflow/python/kernel_tests/clip_ops_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -369,6 +370,21 @@ class ClipTest(test.TestCase): self.assertAllClose(np_ans_0, tf_ans_1) self.assertAllClose(np_ans_1, tf_ans_2) + def testClipByGlobalNormInf(self): + with self.test_session(use_gpu=True): + x0 = constant_op.constant([-2.0, 0.0, np.inf, 4.0, 0.0, 0.0], + shape=[2, 3]) + x1 = constant_op.constant([1.0, -2.0]) + clip_norm = 6.0 + + ans, norm = clip_ops.clip_by_global_norm([x0, x1], clip_norm) + with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"): + norm.eval() + with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"): + ans[0].eval() + with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"): + ans[1].eval() + def testClipByAverageNormClipped(self): # Norm clipping when average clip_norm < 0.83333333 with self.test_session(use_gpu=True): |