aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/check_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/check_ops.py')
-rw-r--r--tensorflow/python/ops/check_ops.py79
1 files changed, 7 insertions, 72 deletions
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index 7e509f72c1..ceee009104 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -48,7 +48,6 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_util
@@ -97,11 +96,10 @@ def _maybe_constant_value_string(t):
def _assert_static(condition, data):
- """Raises a InvalidArgumentError with as much information as possible."""
+ """Raises a static ValueError with as much information as possible."""
if not condition:
data_static = [_maybe_constant_value_string(x) for x in data]
- raise errors.InvalidArgumentError(node_def=None, op=None,
- message='\n'.join(data_static))
+ raise ValueError('\n'.join(data_static))
def assert_proper_iterable(values):
@@ -305,60 +303,11 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
Returns:
Op that raises `InvalidArgumentError` if `x == y` is False.
- @compatibility{eager} returns None
-
- Raises:
- InvalidArgumentError if the check can be performed immediately and
- `x == y` is False. The check can be performed immediately during
- eager execution or if `x` and `y` are statically known.
"""
message = message or ''
with ops.name_scope(name, 'assert_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
-
- if context.in_eager_mode():
- eq = math_ops.equal(x, y)
- condition = math_ops.reduce_all(eq)
- if not condition:
- # Prepare a message with first elements of x and y
- summary_msg = ''
- if summarize:
- # reshape((-1,)) is the fastest way to get a flat array view.
- x_np = x.numpy().reshape((-1,))
- y_np = y.numpy().reshape((-1,))
- x_sum = min(x_np.size, summarize)
- y_sum = min(y_np.size, summarize)
- summary_msg = ('First %d elements of x:\n%s\n'
- 'First %d elements of y:\n%s\n' %
- (x_sum, x_np[:x_sum],
- y_sum, y_np[:y_sum]))
-
- # Get the values that actually differed and their indices
- mask = math_ops.logical_not(eq)
- indices = array_ops.where(mask)
- indices_np = indices.numpy()
- x_vals = array_ops.boolean_mask(x, mask)
- y_vals = array_ops.boolean_mask(y, mask)
- diff_to_print = 0
- if summarize:
- diff_to_print = min(summarize, indices_np.size)
-
- raise errors.InvalidArgumentError(
- node_def=None, op=None,
- message=('%s\nCondition x == y did not hold.\n'
- 'Indices of first %s different values:\n%s\n'
- 'Corresponding x values:\n%s\n'
- 'Corresponding y values:\n%s\n'
- '%s'
- %
- (message or '',
- diff_to_print, indices_np[:diff_to_print],
- x_vals.numpy().reshape((-1,))[:diff_to_print],
- y_vals.numpy().reshape((-1,))[:diff_to_print],
- summary_msg)))
- return
-
if data is None:
data = [
message,
@@ -407,19 +356,12 @@ def assert_none_equal(
with ops.name_scope(name, 'assert_none_equal', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
- if context.in_eager_mode():
- x_name = 'x'
- y_name = 'y'
- else:
- x_name = x.name
- y_name = y.name
-
if data is None:
data = [
message,
- 'Condition x != y did not hold for every single element:',
- 'x (%s) = ' % x_name, x,
- 'y (%s) = ' % y_name, y
+ 'Condition x != y did not hold for every single element:'
+ 'x (%s) = ' % x.name, x,
+ 'y (%s) = ' % y.name, y
]
condition = math_ops.reduce_all(math_ops.not_equal(x, y))
return control_flow_ops.Assert(condition, data, summarize=summarize)
@@ -455,18 +397,11 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None):
with ops.name_scope(name, 'assert_less', [x, y, data]):
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y')
- if context.in_eager_mode():
- x_name = 'x'
- y_name = 'y'
- else:
- x_name = x.name
- y_name = y.name
-
if data is None:
data = [
message,
- 'Condition x < y did not hold element-wise:',
- 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
+ 'Condition x < y did not hold element-wise:'
+ 'x (%s) = ' % x.name, x, 'y (%s) = ' % y.name, y
]
condition = math_ops.reduce_all(math_ops.less(x, y))
return control_flow_ops.Assert(condition, data, summarize=summarize)