diff options
author | 2016-12-21 11:36:38 -0800 | |
---|---|---|
committer | 2016-12-21 11:52:05 -0800 | |
commit | 4a7233145c762253d3f0ada1750e8a02e33a29c3 (patch) | |
tree | 373ffac25dd07110ad99e2c29419b24535bc1a67 | |
parent | 5bf28a164e58097c06fc45362bc4a88626a2c988 (diff) |
Add `assert_rank_in`, to handle cases where target ranks are a list, not an upper or lower bound.
Change: 142684404
-rw-r--r-- | tensorflow/python/kernel_tests/check_ops_test.py | 114 | ||||
-rw-r--r-- | tensorflow/python/ops/check_ops.py | 137 |
2 files changed, 244 insertions, 7 deletions
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index cdfbbbdaf2..a2df4cb2a7 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -554,6 +554,120 @@ class AssertRankTest(test.TestCase): array_ops.identity(tensor).eval(feed_dict={rank_tensor: .5}) +class AssertRankInTest(test.TestCase): + + def test_rank_zero_tensor_raises_if_rank_mismatch_static_rank(self): + with self.test_session(): + tensor_rank0 = constant_op.constant(42, name="my_tensor") + with self.assertRaisesRegexp( + ValueError, "fail.*my_tensor.*must have rank.*in.*1.*2"): + with ops.control_dependencies([ + check_ops.assert_rank_in(tensor_rank0, (1, 2), message="fail")]): + array_ops.identity(tensor_rank0).eval() + + def test_rank_zero_tensor_raises_if_rank_mismatch_dynamic_rank(self): + with self.test_session(): + tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor") + with ops.control_dependencies([ + check_ops.assert_rank_in(tensor_rank0, (1, 2), message="fail")]): + with self.assertRaisesOpError("fail.*my_tensor.*rank"): + array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0}) + + def test_rank_zero_tensor_doesnt_raise_if_rank_matches_static_rank(self): + with self.test_session(): + tensor_rank0 = constant_op.constant(42, name="my_tensor") + for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): + with ops.control_dependencies([ + check_ops.assert_rank_in(tensor_rank0, desired_ranks)]): + array_ops.identity(tensor_rank0).eval() + + def test_rank_zero_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self): + with self.test_session(): + tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor") + for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): + with ops.control_dependencies([ + check_ops.assert_rank_in(tensor_rank0, desired_ranks)]): + array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0}) + + def test_rank_one_tensor_doesnt_raise_if_rank_matches_static_rank(self): + with self.test_session(): + tensor_rank1 = constant_op.constant([42, 43], name="my_tensor") + for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): + with ops.control_dependencies([ + check_ops.assert_rank_in(tensor_rank1, desired_ranks)]): + array_ops.identity(tensor_rank1).eval() + + def test_rank_one_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self): + with self.test_session(): + tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor") + for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)): + with ops.control_dependencies([ + check_ops.assert_rank_in(tensor_rank1, desired_ranks)]): + array_ops.identity(tensor_rank1).eval(feed_dict={ + tensor_rank1: (42.0, 43.0) + }) + + def test_rank_one_tensor_raises_if_rank_mismatches_static_rank(self): + with self.test_session(): + tensor_rank1 = constant_op.constant((42, 43), name="my_tensor") + with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"): + with ops.control_dependencies([ + check_ops.assert_rank_in(tensor_rank1, (0, 2))]): + array_ops.identity(tensor_rank1).eval() + + def test_rank_one_tensor_raises_if_rank_mismatches_dynamic_rank(self): + with self.test_session(): + tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor") + with ops.control_dependencies([ + check_ops.assert_rank_in(tensor_rank1, (0, 2))]): + with self.assertRaisesOpError("my_tensor.*rank"): + array_ops.identity(tensor_rank1).eval(feed_dict={ + tensor_rank1: (42.0, 43.0) + }) + + def test_raises_if_rank_is_not_scalar_static(self): + with self.test_session(): + tensor = constant_op.constant((42, 43), name="my_tensor") + desired_ranks = ( + np.array(1, dtype=np.int32), + np.array((2, 1), dtype=np.int32)) + with self.assertRaisesRegexp(ValueError, "Rank must be a scalar"): + check_ops.assert_rank_in(tensor, desired_ranks) + + def test_raises_if_rank_is_not_scalar_dynamic(self): + with self.test_session(): + tensor = constant_op.constant( + (42, 43), dtype=dtypes.float32, name="my_tensor") + desired_ranks = ( + array_ops.placeholder(dtypes.int32, name="rank0_tensor"), + array_ops.placeholder(dtypes.int32, name="rank1_tensor")) + with self.assertRaisesOpError("Rank must be a scalar"): + with ops.control_dependencies( + (check_ops.assert_rank_in(tensor, desired_ranks),)): + array_ops.identity(tensor).eval(feed_dict={ + desired_ranks[0]: 1, + desired_ranks[1]: [2, 1], + }) + + def test_raises_if_rank_is_not_integer_static(self): + with self.test_session(): + tensor = constant_op.constant((42, 43), name="my_tensor") + with self.assertRaisesRegexp(TypeError, + "must be of type <dtype: 'int32'>"): + check_ops.assert_rank_in(tensor, (1, .5,)) + + def test_raises_if_rank_is_not_integer_dynamic(self): + with self.test_session(): + tensor = constant_op.constant( + (42, 43), dtype=dtypes.float32, name="my_tensor") + rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor") + with self.assertRaisesRegexp(TypeError, + "must be of type <dtype: 'int32'>"): + with ops.control_dependencies( + [check_ops.assert_rank_in(tensor, (1, rank_tensor))]): + array_ops.identity(tensor).eval(feed_dict={rank_tensor: .5}) + + class AssertRankAtLeastTest(test.TestCase): def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self): diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index b50153dcc6..0e088649cd 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -68,6 +68,7 @@ __all__ = [ 'assert_greater_equal', 'assert_rank', 'assert_rank_at_least', + 'assert_rank_in', 'assert_type', 'is_non_decreasing', 'is_numeric_tensor', @@ -465,16 +466,15 @@ def _assert_rank_condition( Raises: ValueError: If static checks determine `x` fails static_condition. """ - # Attempt to statically defined rank. - x_rank_static = x.get_shape().ndims - rank_static = tensor_util.constant_value(rank) - assert_type(rank, dtypes.int32) + # Attempt to statically defined rank. + rank_static = tensor_util.constant_value(rank) if rank_static is not None: if rank_static.ndim != 0: - raise ValueError('Rank must be a scalar') + raise ValueError('Rank must be a scalar.') + x_rank_static = x.get_shape().ndims if x_rank_static is not None: if not static_condition(x_rank_static, rank_static): raise ValueError( @@ -519,7 +519,7 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None): Raises: ValueError: If static checks determine `x` has wrong rank. """ - with ops.name_scope(name, 'assert_rank', [x]): + with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])): x = ops.convert_to_tensor(x, name='x') rank = ops.convert_to_tensor(rank, name='rank') message = message or '' @@ -577,7 +577,8 @@ def assert_rank_at_least( Raises: ValueError: If static checks determine `x` has wrong rank. """ - with ops.name_scope(name, 'assert_rank_at_least', [x]): + with ops.name_scope( + name, 'assert_rank_at_least', (x, rank) + tuple(data or [])): x = ops.convert_to_tensor(x, name='x') rank = ops.convert_to_tensor(rank, name='rank') message = message or '' @@ -606,6 +607,128 @@ def assert_rank_at_least( return assert_op +def _static_rank_in(actual_rank, given_ranks): + return actual_rank in given_ranks + + +def _dynamic_rank_in(actual_rank, given_ranks): + if len(given_ranks) < 1: + return ops.convert_to_tensor(False) + result = math_ops.equal(given_ranks[0], actual_rank) + for given_rank in given_ranks[1:]: + result = math_ops.logical_or( + result, math_ops.equal(given_rank, actual_rank)) + return result + + +def _assert_ranks_condition( + x, ranks, static_condition, dynamic_condition, data, summarize): + """Assert `x` has a rank that satisfies a given condition. + + Args: + x: Numeric `Tensor`. + ranks: Scalar `Tensor`. + static_condition: A python function that takes + `[actual_rank, given_ranks]` and returns `True` if the condition is + satisfied, `False` otherwise. + dynamic_condition: An `op` that takes [actual_rank, given_ranks] + and return `True` if the condition is satisfied, `False` otherwise. + data: The tensors to print out if the condition is false. Defaults to + error message and first few entries of `x`. + summarize: Print this many entries of each tensor. + + Returns: + Op raising `InvalidArgumentError` if `x` fails dynamic_condition. + + Raises: + ValueError: If static checks determine `x` fails static_condition. + """ + for rank in ranks: + assert_type(rank, dtypes.int32) + + # Attempt to statically defined rank. + ranks_static = tuple([tensor_util.constant_value(rank) for rank in ranks]) + if None not in ranks_static: + for rank_static in ranks_static: + if rank_static.ndim != 0: + raise ValueError('Rank must be a scalar.') + + x_rank_static = x.get_shape().ndims + if x_rank_static is not None: + if not static_condition(x_rank_static, ranks_static): + raise ValueError( + 'Static rank condition failed', x_rank_static, ranks_static) + return control_flow_ops.no_op(name='static_checks_determined_all_ok') + + condition = dynamic_condition(array_ops.rank(x), ranks) + + # Add the condition that `rank` must have rank zero. Prevents the bug where + # someone does assert_rank(x, [n]), rather than assert_rank(x, n). + for rank, rank_static in zip(ranks, ranks_static): + if rank_static is None: + this_data = ['Rank must be a scalar. Received rank: ', rank] + rank_check = assert_rank(rank, 0, data=this_data) + condition = control_flow_ops.with_dependencies([rank_check], condition) + + return control_flow_ops.Assert(condition, data, summarize=summarize) + + +def assert_rank_in( + x, ranks, data=None, summarize=None, message=None, name=None): + """Assert `x` has rank in `ranks`. + + Example of adding a dependency to an operation: + + ```python + with tf.control_dependencies([tf.assert_rank_in(x, (2, 4))]): + output = tf.reduce_sum(x) + ``` + + Args: + x: Numeric `Tensor`. + ranks: Iterable of scalar `Tensor` objects. + data: The tensors to print out if the condition is False. Defaults to + error message and first few entries of `x`. + summarize: Print this many entries of each tensor. + message: A string to prefix to the default message. + name: A name for this operation (optional). + Defaults to "assert_rank_in". + + Returns: + Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`. + If static checks determine `x` has matching rank, a `no_op` is returned. + + Raises: + ValueError: If static checks determine `x` has mismatched rank. + """ + with ops.name_scope( + name, 'assert_rank_in', (x,) + tuple(ranks) + tuple(data or [])): + x = ops.convert_to_tensor(x, name='x') + ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks]) + message = message or '' + + if data is None: + data = [ + message, 'Tensor %s must have rank in' % x.name + ] + list(ranks) + [ + 'Received shape: ', array_ops.shape(x) + ] + + try: + assert_op = _assert_ranks_condition(x, ranks, _static_rank_in, + _dynamic_rank_in, data, summarize) + + except ValueError as e: + if e.args[0] == 'Static rank condition failed': + raise ValueError( + '%s. Tensor %s must have rank in %s. Received rank %d, ' + 'shape %s' % (message, x.name, e.args[2], e.args[1], x.get_shape())) + else: + raise + + return assert_op + + def assert_integer(x, message=None, name=None): """Assert that `x` is of integer dtype. |