aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/clip_ops_test.py
diff options
context:
space:
mode:
authorGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-09 19:10:33 +0800
committerGravatar Yan Facai (颜发才) <facai.yan@gmail.com>2018-08-09 19:21:17 +0800
commite7844a664bf999946b4da022cdf2a538c2a244ae (patch)
tree2a624fbd98330b9d24feef687646a1f2c8cf8603 /tensorflow/python/kernel_tests/clip_ops_test.py
parente1b825ded7585f5ab83634ebaa7c0b15ad787fc5 (diff)
ENH: use assertion, use_norm must be finite
Diffstat (limited to 'tensorflow/python/kernel_tests/clip_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/clip_ops_test.py15
1 files changed, 4 insertions, 11 deletions
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
index dacb7bd8df..0de953f465 100644
--- a/tensorflow/python/kernel_tests/clip_ops_test.py
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -370,23 +371,15 @@ class ClipTest(test.TestCase):
self.assertAllClose(np_ans_1, tf_ans_2)
def testClipByGlobalNormInf(self):
- # Norm = inf, return NaN
with self.test_session(use_gpu=True):
x0 = constant_op.constant([-2.0, 0.0, np.inf, 4.0, 0.0, 0.0],
shape=[2, 3])
x1 = constant_op.constant([1.0, -2.0])
- np_ans_0 = [[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan]]
- np_ans_1 = [np.nan, np.nan]
clip_norm = 6.0
- ans, norm = clip_ops.clip_by_global_norm([x0, x1], clip_norm)
- tf_ans_1 = ans[0].eval()
- tf_ans_2 = ans[1].eval()
- tf_norm = norm.eval()
-
- self.assertAllClose(tf_norm, np.inf)
- self.assertAllClose(np_ans_0, tf_ans_1)
- self.assertAllClose(np_ans_1, tf_ans_2)
+ _, norm = clip_ops.clip_by_global_norm([x0, x1], clip_norm)
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "use_norm"):
+ norm.eval()
def testClipByAverageNormClipped(self):
# Norm clipping when average clip_norm < 0.83333333