diff options
author | 2017-03-22 09:21:48 -0800 | |
---|---|---|
committer | 2017-03-22 10:48:37 -0700 | |
commit | b35a22c36a2500bf00d14fab16276ba75cc7af2f (patch) | |
tree | 4a9c77fb444263f8b1135798a425d9e8e4992461 | |
parent | 53bf26653df065fe4ca242aaf893b18a60c7f80b (diff) |
tf.clip_by_value: prevent unintentional broadcasting
where the output is not compatible with the input.
The clipping should never produce a shape that is different than
the input. We can catch such cases by using TensorShape's merge
to validate (when statically known) when this accidental broadcasting
is happening, though we have no way of preventing such broadcasting when the
user chooses dynamic dimensions (e.g., `None` batch sizes).
Adds a test for a case that should raise an error but previously did not.
RELNOTES: clip_by_value and clip_by_norm contract tightened to ensure
output shape matches input shape (no accidental broadcasting). This is
a bugfix that may break models that have a bug in them (clipping ops
by definition should not change the shape of the input tensor).
Change: 150893660
-rw-r--r-- | tensorflow/python/kernel_tests/clip_ops_test.py | 18 | ||||
-rw-r--r-- | tensorflow/python/ops/clip_ops.py | 15 |
2 files changed, 32 insertions, 1 deletions
diff --git a/tensorflow/python/kernel_tests/clip_ops_test.py b/tensorflow/python/kernel_tests/clip_ops_test.py index bbd1ab46ae..5c8b71da17 100644 --- a/tensorflow/python/kernel_tests/clip_ops_test.py +++ b/tensorflow/python/kernel_tests/clip_ops_test.py @@ -37,6 +37,16 @@ class ClipTest(test.TestCase): self.assertAllClose(np_ans, tf_ans) + def testClipByValueBadShape(self): + with self.test_session(): + x = constant_op.constant([-5.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3, 1]) + # Use a nonsensical shape. + clip = constant_op.constant([1.0, 2.0]) + with self.assertRaises(ValueError): + _ = clip_ops.clip_by_value(x, -clip, clip) + with self.assertRaises(ValueError): + _ = clip_ops.clip_by_value(x, 1.0, clip) + def testClipByValueNonFinite(self): with self.test_session(): x = constant_op.constant([float('NaN'), float('Inf'), -float('Inf')]) @@ -65,6 +75,14 @@ class ClipTest(test.TestCase): self.assertAllClose(np_ans, tf_ans) self.assertAllClose(np_ans, tf_ans_tensor) + def testClipByNormBadShape(self): + with self.test_session(): + x = constant_op.constant([-3.0, 0.0, 0.0, 4.0, 0.0, 0.0], shape=[2, 3, 1]) + # Use a nonsensical shape. + clip = constant_op.constant([1.0, 2.0]) + with self.assertRaises(ValueError): + _ = clip_ops.clip_by_norm(x, clip) + def testClipByNormNotClipped(self): # No norm clipping when clip_norm >= 5 with self.test_session(): diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py index 3dc0ac34c8..7430c28583 100644 --- a/tensorflow/python/ops/clip_ops.py +++ b/tensorflow/python/ops/clip_ops.py @@ -49,6 +49,10 @@ def clip_by_value(t, clip_value_min, clip_value_max, Returns: A clipped `Tensor`. + + Raises: + ValueError: if the clip tensors would trigger array broadcasting + that would make the returned tensor larger than the input. 
""" with ops.name_scope(name, "clip_by_value", [t, clip_value_min, clip_value_max]) as name: @@ -56,7 +60,12 @@ def clip_by_value(t, clip_value_min, clip_value_max, # Go through list of tensors, for each value in each tensor clip t_min = math_ops.minimum(t, clip_value_max) + # Assert that the shape is compatible with the initial shape, + # to prevent unintentional broadcasting. + _ = t.shape.merge_with(t_min.shape) + t_max = math_ops.maximum(t_min, clip_value_min, name=name) + _ = t.shape.merge_with(t_max.shape) return t_max @@ -100,7 +109,11 @@ def clip_by_norm(t, clip_norm, axes=None, name=None): # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm l2norm_inv = math_ops.rsqrt( math_ops.reduce_sum(t * t, axes, keep_dims=True)) - tclip = array_ops.identity(t * clip_norm * math_ops.minimum( + intermediate = t * clip_norm + # Assert that the shape is compatible with the initial shape, + # to prevent unintentional broadcasting. + _ = t.shape.merge_with(intermediate.shape) + tclip = array_ops.identity(intermediate * math_ops.minimum( l2norm_inv, constant_op.constant(1.0, dtype=t.dtype) / clip_norm), name=name) |