aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2017-03-22 09:21:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-22 10:48:37 -0700
commitb35a22c36a2500bf00d14fab16276ba75cc7af2f (patch)
tree4a9c77fb444263f8b1135798a425d9e8e4992461
parent53bf26653df065fe4ca242aaf893b18a60c7f80b (diff)
tf.clip_by_value: prevent unintentional broadcasting
where the output is not compatible with the input. The clipping should never produce a shape that is different than the input. We can catch such cases by using TensorShape's merge to validate (when statically known) when this accidental broadcasting is happening, though we have no way of preventing this when the user chooses dynamic batch sizes (e.g., None's). Adds a test for a test case that shouldn't pass, which did before. RELNOTES: clip_by_value and clip_by_norm contract tightened to ensure output shape matches input shape (no accidental broadcasting). This is a bugfix that may break models that have a bug in them (clipping ops by definition should not change the shape of the input tensor). Change: 150893660
-rw-r--r--tensorflow/python/kernel_tests/clip_ops_test.py18
-rw-r--r--tensorflow/python/ops/clip_ops.py15
2 files changed, 32 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py
index bbd1ab46ae..5c8b71da17 100644
--- a/tensorflow/python/kernel_tests/clip_ops_test.py
+++ b/tensorflow/python/kernel_tests/clip_ops_test.py
@@ -37,6 +37,16 @@ class ClipTest(test.TestCase):
self.assertAllClose(np_ans, tf_ans)
+ def testClipByValueBadShape(self):
+ with self.test_session():
+ x = constant_op.constant([-5.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3, 1])
+ # Use a nonsensical shape.
+ clip = constant_op.constant([1.0, 2.0])
+ with self.assertRaises(ValueError):
+ _ = clip_ops.clip_by_value(x, -clip, clip)
+ with self.assertRaises(ValueError):
+ _ = clip_ops.clip_by_value(x, 1.0, clip)
+
def testClipByValueNonFinite(self):
with self.test_session():
x = constant_op.constant([float('NaN'), float('Inf'), -float('Inf')])
@@ -65,6 +75,14 @@ class ClipTest(test.TestCase):
self.assertAllClose(np_ans, tf_ans)
self.assertAllClose(np_ans, tf_ans_tensor)
+ def testClipByNormBadShape(self):
+ with self.test_session():
+ x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3, 1])
+ # Use a nonsensical shape.
+ clip = constant_op.constant([1.0, 2.0])
+ with self.assertRaises(ValueError):
+ _ = clip_ops.clip_by_norm(x, clip)
+
def testClipByNormNotClipped(self):
# No norm clipping when clip_norm >= 5
with self.test_session():
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index 3dc0ac34c8..7430c28583 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -49,6 +49,10 @@ def clip_by_value(t, clip_value_min, clip_value_max,
Returns:
A clipped `Tensor`.
+
+ Raises:
+ ValueError: if the clip tensors would trigger array broadcasting
+ that would make the returned tensor larger than the input.
"""
with ops.name_scope(name, "clip_by_value",
[t, clip_value_min, clip_value_max]) as name:
@@ -56,7 +60,12 @@ def clip_by_value(t, clip_value_min, clip_value_max,
# Go through list of tensors, for each value in each tensor clip
t_min = math_ops.minimum(t, clip_value_max)
+ # Assert that the shape is compatible with the initial shape,
+ # to prevent unintentional broadcasting.
+ _ = t.shape.merge_with(t_min.shape)
+
t_max = math_ops.maximum(t_min, clip_value_min, name=name)
+ _ = t.shape.merge_with(t_max.shape)
return t_max
@@ -100,7 +109,11 @@ def clip_by_norm(t, clip_norm, axes=None, name=None):
# Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm
l2norm_inv = math_ops.rsqrt(
math_ops.reduce_sum(t * t, axes, keep_dims=True))
- tclip = array_ops.identity(t * clip_norm * math_ops.minimum(
+ intermediate = t * clip_norm
+ # Assert that the shape is compatible with the initial shape,
+ # to prevent unintentional broadcasting.
+ _ = t.shape.merge_with(intermediate.shape)
+ tclip = array_ops.identity(intermediate * math_ops.minimum(
l2norm_inv, constant_op.constant(1.0, dtype=t.dtype) / clip_norm),
name=name)