diff options
Diffstat (limited to 'tensorflow/python/ops/check_ops.py')
-rw-r--r-- | tensorflow/python/ops/check_ops.py | 79 |
1 files changed, 7 insertions, 72 deletions
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 7e509f72c1..ceee009104 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -48,7 +48,6 @@ import numpy as np from tensorflow.python.eager import context from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_util @@ -97,11 +96,10 @@ def _maybe_constant_value_string(t): def _assert_static(condition, data): - """Raises a InvalidArgumentError with as much information as possible.""" + """Raises a static ValueError with as much information as possible.""" if not condition: data_static = [_maybe_constant_value_string(x) for x in data] - raise errors.InvalidArgumentError(node_def=None, op=None, - message='\n'.join(data_static)) + raise ValueError('\n'.join(data_static)) def assert_proper_iterable(values): @@ -305,60 +303,11 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): Returns: Op that raises `InvalidArgumentError` if `x == y` is False. - @compatibility{eager} returns None - - Raises: - InvalidArgumentError if the check can be performed immediately and - `x == y` is False. The check can be performed immediately during - eager execution or if `x` and `y` are statically known. """ message = message or '' with ops.name_scope(name, 'assert_equal', [x, y, data]): x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y') - - if context.in_eager_mode(): - eq = math_ops.equal(x, y) - condition = math_ops.reduce_all(eq) - if not condition: - # Prepare a message with first elements of x and y - summary_msg = '' - if summarize: - # reshape((-1,)) is the fastest way to get a flat array view. - x_np = x.numpy().reshape((-1,)) - y_np = y.numpy().reshape((-1,)) - x_sum = min(x_np.size, summarize) - y_sum = min(y_np.size, summarize) - summary_msg = ('First %d elements of x:\n%s\n' - 'First %d elements of y:\n%s\n' % - (x_sum, x_np[:x_sum], - y_sum, y_np[:y_sum])) - - # Get the values that actually differed and their indices - mask = math_ops.logical_not(eq) - indices = array_ops.where(mask) - indices_np = indices.numpy() - x_vals = array_ops.boolean_mask(x, mask) - y_vals = array_ops.boolean_mask(y, mask) - diff_to_print = 0 - if summarize: - diff_to_print = min(summarize, indices_np.size) - - raise errors.InvalidArgumentError( - node_def=None, op=None, - message=('%s\nCondition x == y did not hold.\n' - 'Indices of first %s different values:\n%s\n' - 'Corresponding x values:\n%s\n' - 'Corresponding y values:\n%s\n' - '%s' - % - (message or '', - diff_to_print, indices_np[:diff_to_print], - x_vals.numpy().reshape((-1,))[:diff_to_print], - y_vals.numpy().reshape((-1,))[:diff_to_print], - summary_msg))) - return - if data is None: data = [ message, @@ -407,19 +356,12 @@ def assert_none_equal( with ops.name_scope(name, 'assert_none_equal', [x, y, data]): x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y') - if context.in_eager_mode(): - x_name = 'x' - y_name = 'y' - else: - x_name = x.name - y_name = y.name - if data is None: data = [ message, - 'Condition x != y did not hold for every single element:', - 'x (%s) = ' % x_name, x, - 'y (%s) = ' % y_name, y + 'Condition x != y did not hold for every single element:' + 'x (%s) = ' % x.name, x, + 'y (%s) = ' % y.name, y ] condition = math_ops.reduce_all(math_ops.not_equal(x, y)) return control_flow_ops.Assert(condition, data, summarize=summarize) @@ -455,18 +397,11 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None): with ops.name_scope(name, 'assert_less', [x, y, data]): x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y') - if context.in_eager_mode(): - x_name = 'x' - y_name = 'y' - else: - x_name = x.name - y_name = y.name - if data is None: data = [ message, - 'Condition x < y did not hold element-wise:', - 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y + 'Condition x < y did not hold element-wise:' + 'x (%s) = ' % x.name, x, 'y (%s) = ' % y.name, y ] condition = math_ops.reduce_all(math_ops.less(x, y)) return control_flow_ops.Assert(condition, data, summarize=summarize) |