aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/scatter_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/scatter_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/scatter_ops_test.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
index c70a4ffce7..1a0fa744ae 100644
--- a/tensorflow/python/kernel_tests/scatter_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -159,7 +159,13 @@ class ScatterTest(test.TestCase):
# Clips small values to avoid division by zero.
def clip_small_values(x):
- return 1e-4 * np.sign(x) if np.abs(x) < 1e-4 else x
+ threshold = 1e-4
+ sign = np.sign(x)
+
+ if isinstance(x, np.int32):
+ threshold = 1
+ sign = np.random.choice([-1, 1])
+ return threshold * sign if np.abs(x) < threshold else x
updates = np.vectorize(clip_small_values)(updates)
old = _AsType(np.random.randn(*((first_dim,) + extra_shape)), vtype)
@@ -181,7 +187,11 @@ class ScatterTest(test.TestCase):
tf_scatter,
repeat_indices=False,
updates_are_scalar=False):
- for vtype in (np.float32, np.float64):
+ vtypes = [np.float32, np.float64]
+ if tf_scatter != state_ops.scatter_div:
+ vtypes.append(np.int32)
+
+ for vtype in vtypes:
for itype in (np.int32, np.int64):
self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices,
updates_are_scalar)