aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-12 16:10:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-12 16:10:42 -0700
commit087190246b95dc4c188f630ca90880a12e39b557 (patch)
tree3e687f9ea89fc7072b954716f16a5e7b2a79d3b1
parent1a22b0b982fa1a953651b98af8f3cd30542048fd (diff)
parenta170aef5b6703ec7d4819aaadc4bcd9c8f6cb017 (diff)
Merge pull request #21428 from facaiy:BUG/clip_by_global_norm_with_inf
PiperOrigin-RevId: 208412584
-rw-r--r--tensorflow/python/kernel_tests/clip_ops_test.py16
-rw-r--r--tensorflow/python/ops/clip_ops.py6
2 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
index fb52d10475..400d38b936 100644
--- a/tensorflow/python/kernel_tests/clip_ops_test.py
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -369,6 +370,21 @@ class ClipTest(test.TestCase):
self.assertAllClose(np_ans_0, tf_ans_1)
self.assertAllClose(np_ans_1, tf_ans_2)
+ def testClipByGlobalNormInf(self):
+ with self.test_session(use_gpu=True):
+ x0 = constant_op.constant([-2.0, 0.0, np.inf, 4.0, 0.0, 0.0],
+ shape=[2, 3])
+ x1 = constant_op.constant([1.0, -2.0])
+ clip_norm = 6.0
+
+ ans, norm = clip_ops.clip_by_global_norm([x0, x1], clip_norm)
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"):
+ norm.eval()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"):
+ ans[0].eval()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "global norm"):
+ ans[1].eval()
+
def testClipByAverageNormClipped(self):
# Norm clipping when average clip_norm < 0.83333333
with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index e2580e8a2e..78b395a6c1 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import numerics
from tensorflow.python.util.tf_export import tf_export
@@ -57,7 +58,7 @@ def clip_by_value(t, clip_value_min, clip_value_max,
A clipped `Tensor`.
Raises:
- ValueError: if the clip tensors would trigger array broadcasting
+ ValueError: If the clip tensors would trigger array broadcasting
that would make the returned tensor larger than the input.
"""
with ops.name_scope(name, "clip_by_value",
@@ -246,6 +247,7 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
Raises:
TypeError: If `t_list` is not a sequence.
+ InvalidArgumentError: If global norm is not finite.
"""
if (not isinstance(t_list, collections.Sequence)
or isinstance(t_list, six.string_types)):
@@ -253,6 +255,8 @@ def clip_by_global_norm(t_list, clip_norm, use_norm=None, name=None):
t_list = list(t_list)
if use_norm is None:
use_norm = global_norm(t_list, name)
+ use_norm = numerics.verify_tensor_all_finite(use_norm,
+ "Found Inf or NaN global norm.")
with ops.name_scope(name, "clip_by_global_norm",
t_list + [clip_norm]) as name: