aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/clip_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/clip_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/clip_ops_test.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
index fb52d10475..400d38b936 100644
--- a/tensorflow/python/kernel_tests/clip_ops_test.py
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -369,6 +370,21 @@ class ClipTest(test.TestCase):
self.assertAllClose(np_ans_0, tf_ans_1)
self.assertAllClose(np_ans_1, tf_ans_2)
+ def testClipByGlobalNormInf(self):
+ with self.test_session(use_gpu=True):
+ x0 = constant_op.constant([-2.0, 0.0, np.inf, 4.0, 0.0, 0.0],
+ shape=[2, 3])
+ x1 = constant_op.constant([1.0, -2.0])
+ clip_norm = 6.0
+
+ ans, norm = clip_ops.clip_by_global_norm([x0, x1], clip_norm)
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"):
+ norm.eval()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"):
+ ans[0].eval()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"):
+ ans[1].eval()
+
def testClipByAverageNormClipped(self):
# Norm clipping when average clip_norm < 0.83333333
with self.test_session(use_gpu=True):