aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-21 11:36:38 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-21 11:52:05 -0800
commit4a7233145c762253d3f0ada1750e8a02e33a29c3 (patch)
tree373ffac25dd07110ad99e2c29419b24535bc1a67
parent5bf28a164e58097c06fc45362bc4a88626a2c988 (diff)
Add `assert_rank_in`, to handle cases where target ranks are a list, not an upper or lower bound.
Change: 142684404
-rw-r--r--tensorflow/python/kernel_tests/check_ops_test.py114
-rw-r--r--tensorflow/python/ops/check_ops.py137
2 files changed, 244 insertions, 7 deletions
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index cdfbbbdaf2..a2df4cb2a7 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -554,6 +554,120 @@ class AssertRankTest(test.TestCase):
array_ops.identity(tensor).eval(feed_dict={rank_tensor: .5})
+class AssertRankInTest(test.TestCase):
+
+ def test_rank_zero_tensor_raises_if_rank_mismatch_static_rank(self):
+ with self.test_session():
+ tensor_rank0 = constant_op.constant(42, name="my_tensor")
+ with self.assertRaisesRegexp(
+ ValueError, "fail.*my_tensor.*must have rank.*in.*1.*2"):
+ with ops.control_dependencies([
+ check_ops.assert_rank_in(tensor_rank0, (1, 2), message="fail")]):
+ array_ops.identity(tensor_rank0).eval()
+
+ def test_rank_zero_tensor_raises_if_rank_mismatch_dynamic_rank(self):
+ with self.test_session():
+ tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor")
+ with ops.control_dependencies([
+ check_ops.assert_rank_in(tensor_rank0, (1, 2), message="fail")]):
+ with self.assertRaisesOpError("fail.*my_tensor.*rank"):
+ array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0})
+
+ def test_rank_zero_tensor_doesnt_raise_if_rank_matches_static_rank(self):
+ with self.test_session():
+ tensor_rank0 = constant_op.constant(42, name="my_tensor")
+ for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
+ with ops.control_dependencies([
+ check_ops.assert_rank_in(tensor_rank0, desired_ranks)]):
+ array_ops.identity(tensor_rank0).eval()
+
+ def test_rank_zero_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
+ with self.test_session():
+ tensor_rank0 = array_ops.placeholder(dtypes.float32, name="my_tensor")
+ for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
+ with ops.control_dependencies([
+ check_ops.assert_rank_in(tensor_rank0, desired_ranks)]):
+ array_ops.identity(tensor_rank0).eval(feed_dict={tensor_rank0: 42.0})
+
+ def test_rank_one_tensor_doesnt_raise_if_rank_matches_static_rank(self):
+ with self.test_session():
+ tensor_rank1 = constant_op.constant([42, 43], name="my_tensor")
+ for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
+ with ops.control_dependencies([
+ check_ops.assert_rank_in(tensor_rank1, desired_ranks)]):
+ array_ops.identity(tensor_rank1).eval()
+
+ def test_rank_one_tensor_doesnt_raise_if_rank_matches_dynamic_rank(self):
+ with self.test_session():
+ tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor")
+ for desired_ranks in ((0, 1, 2), (1, 0, 2), (1, 2, 0)):
+ with ops.control_dependencies([
+ check_ops.assert_rank_in(tensor_rank1, desired_ranks)]):
+ array_ops.identity(tensor_rank1).eval(feed_dict={
+ tensor_rank1: (42.0, 43.0)
+ })
+
+ def test_rank_one_tensor_raises_if_rank_mismatches_static_rank(self):
+ with self.test_session():
+ tensor_rank1 = constant_op.constant((42, 43), name="my_tensor")
+ with self.assertRaisesRegexp(ValueError, "my_tensor.*rank"):
+ with ops.control_dependencies([
+ check_ops.assert_rank_in(tensor_rank1, (0, 2))]):
+ array_ops.identity(tensor_rank1).eval()
+
+ def test_rank_one_tensor_raises_if_rank_mismatches_dynamic_rank(self):
+ with self.test_session():
+ tensor_rank1 = array_ops.placeholder(dtypes.float32, name="my_tensor")
+ with ops.control_dependencies([
+ check_ops.assert_rank_in(tensor_rank1, (0, 2))]):
+ with self.assertRaisesOpError("my_tensor.*rank"):
+ array_ops.identity(tensor_rank1).eval(feed_dict={
+ tensor_rank1: (42.0, 43.0)
+ })
+
+ def test_raises_if_rank_is_not_scalar_static(self):
+ with self.test_session():
+ tensor = constant_op.constant((42, 43), name="my_tensor")
+ desired_ranks = (
+ np.array(1, dtype=np.int32),
+ np.array((2, 1), dtype=np.int32))
+ with self.assertRaisesRegexp(ValueError, "Rank must be a scalar"):
+ check_ops.assert_rank_in(tensor, desired_ranks)
+
+ def test_raises_if_rank_is_not_scalar_dynamic(self):
+ with self.test_session():
+ tensor = constant_op.constant(
+ (42, 43), dtype=dtypes.float32, name="my_tensor")
+ desired_ranks = (
+ array_ops.placeholder(dtypes.int32, name="rank0_tensor"),
+ array_ops.placeholder(dtypes.int32, name="rank1_tensor"))
+ with self.assertRaisesOpError("Rank must be a scalar"):
+ with ops.control_dependencies(
+ (check_ops.assert_rank_in(tensor, desired_ranks),)):
+ array_ops.identity(tensor).eval(feed_dict={
+ desired_ranks[0]: 1,
+ desired_ranks[1]: [2, 1],
+ })
+
+ def test_raises_if_rank_is_not_integer_static(self):
+ with self.test_session():
+ tensor = constant_op.constant((42, 43), name="my_tensor")
+ with self.assertRaisesRegexp(TypeError,
+ "must be of type <dtype: 'int32'>"):
+ check_ops.assert_rank_in(tensor, (1, .5,))
+
+ def test_raises_if_rank_is_not_integer_dynamic(self):
+ with self.test_session():
+ tensor = constant_op.constant(
+ (42, 43), dtype=dtypes.float32, name="my_tensor")
+ rank_tensor = array_ops.placeholder(dtypes.float32, name="rank_tensor")
+ with self.assertRaisesRegexp(TypeError,
+ "must be of type <dtype: 'int32'>"):
+ with ops.control_dependencies(
+ [check_ops.assert_rank_in(tensor, (1, rank_tensor))]):
+ array_ops.identity(tensor).eval(feed_dict={rank_tensor: .5})
+
+
class AssertRankAtLeastTest(test.TestCase):
def test_rank_zero_tensor_raises_if_rank_too_small_static_rank(self):
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index b50153dcc6..0e088649cd 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -68,6 +68,7 @@ __all__ = [
'assert_greater_equal',
'assert_rank',
'assert_rank_at_least',
+ 'assert_rank_in',
'assert_type',
'is_non_decreasing',
'is_numeric_tensor',
@@ -465,16 +466,15 @@ def _assert_rank_condition(
Raises:
ValueError: If static checks determine `x` fails static_condition.
"""
- # Attempt to statically defined rank.
- x_rank_static = x.get_shape().ndims
- rank_static = tensor_util.constant_value(rank)
-
assert_type(rank, dtypes.int32)
+ # Attempt to statically defined rank.
+ rank_static = tensor_util.constant_value(rank)
if rank_static is not None:
if rank_static.ndim != 0:
- raise ValueError('Rank must be a scalar')
+ raise ValueError('Rank must be a scalar.')
+ x_rank_static = x.get_shape().ndims
if x_rank_static is not None:
if not static_condition(x_rank_static, rank_static):
raise ValueError(
@@ -519,7 +519,7 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
Raises:
ValueError: If static checks determine `x` has wrong rank.
"""
- with ops.name_scope(name, 'assert_rank', [x]):
+ with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])):
x = ops.convert_to_tensor(x, name='x')
rank = ops.convert_to_tensor(rank, name='rank')
message = message or ''
@@ -577,7 +577,8 @@ def assert_rank_at_least(
Raises:
ValueError: If static checks determine `x` has wrong rank.
"""
- with ops.name_scope(name, 'assert_rank_at_least', [x]):
+ with ops.name_scope(
+ name, 'assert_rank_at_least', (x, rank) + tuple(data or [])):
x = ops.convert_to_tensor(x, name='x')
rank = ops.convert_to_tensor(rank, name='rank')
message = message or ''
@@ -606,6 +607,128 @@ def assert_rank_at_least(
return assert_op
+def _static_rank_in(actual_rank, given_ranks):
+ return actual_rank in given_ranks
+
+
+def _dynamic_rank_in(actual_rank, given_ranks):
+ if len(given_ranks) < 1:
+ return ops.convert_to_tensor(False)
+ result = math_ops.equal(given_ranks[0], actual_rank)
+ for given_rank in given_ranks[1:]:
+ result = math_ops.logical_or(
+ result, math_ops.equal(given_rank, actual_rank))
+ return result
+
+
+def _assert_ranks_condition(
+ x, ranks, static_condition, dynamic_condition, data, summarize):
+ """Assert `x` has a rank that satisfies a given condition.
+
+ Args:
+ x: Numeric `Tensor`.
+ ranks: Scalar `Tensor`.
+ static_condition: A python function that takes
+ `[actual_rank, given_ranks]` and returns `True` if the condition is
+ satisfied, `False` otherwise.
+ dynamic_condition: An `op` that takes [actual_rank, given_ranks]
+ and return `True` if the condition is satisfied, `False` otherwise.
+ data: The tensors to print out if the condition is false. Defaults to
+ error message and first few entries of `x`.
+ summarize: Print this many entries of each tensor.
+
+ Returns:
+ Op raising `InvalidArgumentError` if `x` fails dynamic_condition.
+
+ Raises:
+ ValueError: If static checks determine `x` fails static_condition.
+ """
+ for rank in ranks:
+ assert_type(rank, dtypes.int32)
+
+ # Attempt to statically defined rank.
+ ranks_static = tuple([tensor_util.constant_value(rank) for rank in ranks])
+ if None not in ranks_static:
+ for rank_static in ranks_static:
+ if rank_static.ndim != 0:
+ raise ValueError('Rank must be a scalar.')
+
+ x_rank_static = x.get_shape().ndims
+ if x_rank_static is not None:
+ if not static_condition(x_rank_static, ranks_static):
+ raise ValueError(
+ 'Static rank condition failed', x_rank_static, ranks_static)
+ return control_flow_ops.no_op(name='static_checks_determined_all_ok')
+
+ condition = dynamic_condition(array_ops.rank(x), ranks)
+
+ # Add the condition that `rank` must have rank zero. Prevents the bug where
+ # someone does assert_rank(x, [n]), rather than assert_rank(x, n).
+ for rank, rank_static in zip(ranks, ranks_static):
+ if rank_static is None:
+ this_data = ['Rank must be a scalar. Received rank: ', rank]
+ rank_check = assert_rank(rank, 0, data=this_data)
+ condition = control_flow_ops.with_dependencies([rank_check], condition)
+
+ return control_flow_ops.Assert(condition, data, summarize=summarize)
+
+
+def assert_rank_in(
+ x, ranks, data=None, summarize=None, message=None, name=None):
+ """Assert `x` has rank in `ranks`.
+
+ Example of adding a dependency to an operation:
+
+ ```python
+ with tf.control_dependencies([tf.assert_rank_in(x, (2, 4))]):
+ output = tf.reduce_sum(x)
+ ```
+
+ Args:
+ x: Numeric `Tensor`.
+ ranks: Iterable of scalar `Tensor` objects.
+ data: The tensors to print out if the condition is False. Defaults to
+ error message and first few entries of `x`.
+ summarize: Print this many entries of each tensor.
+ message: A string to prefix to the default message.
+ name: A name for this operation (optional).
+ Defaults to "assert_rank_in".
+
+ Returns:
+ Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
+ If static checks determine `x` has matching rank, a `no_op` is returned.
+
+ Raises:
+ ValueError: If static checks determine `x` has mismatched rank.
+ """
+ with ops.name_scope(
+ name, 'assert_rank_in', (x,) + tuple(ranks) + tuple(data or [])):
+ x = ops.convert_to_tensor(x, name='x')
+ ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks])
+ message = message or ''
+
+ if data is None:
+ data = [
+ message, 'Tensor %s must have rank in' % x.name
+ ] + list(ranks) + [
+ 'Received shape: ', array_ops.shape(x)
+ ]
+
+ try:
+ assert_op = _assert_ranks_condition(x, ranks, _static_rank_in,
+ _dynamic_rank_in, data, summarize)
+
+ except ValueError as e:
+ if e.args[0] == 'Static rank condition failed':
+ raise ValueError(
+ '%s. Tensor %s must have rank in %s. Received rank %d, '
+ 'shape %s' % (message, x.name, e.args[2], e.args[1], x.get_shape()))
+ else:
+ raise
+
+ return assert_op
+
+
def assert_integer(x, message=None, name=None):
"""Assert that `x` is of integer dtype.