diff options
author | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-08-09 19:10:33 +0800 |
---|---|---|
committer | Yan Facai (颜发才) <facai.yan@gmail.com> | 2018-08-09 19:21:17 +0800 |
commit | e7844a664bf999946b4da022cdf2a538c2a244ae (patch) | |
tree | 2a624fbd98330b9d24feef687646a1f2c8cf8603 /tensorflow/python/kernel_tests/clip_ops_test.py | |
parent | e1b825ded7585f5ab83634ebaa7c0b15ad787fc5 (diff) |
ENH: use assertion, use_norm must be finite
Diffstat (limited to 'tensorflow/python/kernel_tests/clip_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/clip_ops_test.py | 15 |
1 files changed, 4 insertions, 11 deletions
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py index dacb7bd8df..0de953f465 100644 --- a/tensorflow/python/kernel_tests/clip_ops_test.py +++ b/tensorflow/python/kernel_tests/clip_ops_test.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops @@ -370,23 +371,15 @@ class ClipTest(test.TestCase): self.assertAllClose(np_ans_1, tf_ans_2) def testClipByGlobalNormInf(self): - # Norm = inf, return NaN with self.test_session(use_gpu=True): x0 = constant_op.constant([-2.0, 0.0, np.inf, 4.0, 0.0, 0.0], shape=[2, 3]) x1 = constant_op.constant([1.0, -2.0]) - np_ans_0 = [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]] - np_ans_1 = [np.nan, np.nan] clip_norm = 6.0 - ans, norm = clip_ops.clip_by_global_norm([x0, x1], clip_norm) - tf_ans_1 = ans[0].eval() - tf_ans_2 = ans[1].eval() - tf_norm = norm.eval() - - self.assertAllClose(tf_norm, np.inf) - self.assertAllClose(np_ans_0, tf_ans_1) - self.assertAllClose(np_ans_1, tf_ans_2) + _, norm = clip_ops.clip_by_global_norm([x0, x1], clip_norm) + with self.assertRaisesRegexp(errors.InvalidArgumentError, "use_norm"): + norm.eval() def testClipByAverageNormClipped(self): # Norm clipping when average clip_norm < 0.83333333 |