diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/scatter_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/scatter_ops_test.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py index c70a4ffce7..1a0fa744ae 100644 --- a/tensorflow/python/kernel_tests/scatter_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_ops_test.py @@ -159,7 +159,13 @@ class ScatterTest(test.TestCase): # Clips small values to avoid division by zero. def clip_small_values(x): - return 1e-4 * np.sign(x) if np.abs(x) < 1e-4 else x + threshold = 1e-4 + sign = np.sign(x) + + if isinstance(x, np.int32): + threshold = 1 + sign = np.random.choice([-1, 1]) + return threshold * sign if np.abs(x) < threshold else x updates = np.vectorize(clip_small_values)(updates) old = _AsType(np.random.randn(*((first_dim,) + extra_shape)), vtype) @@ -181,7 +187,11 @@ class ScatterTest(test.TestCase): tf_scatter, repeat_indices=False, updates_are_scalar=False): - for vtype in (np.float32, np.float64): + vtypes = [np.float32, np.float64] + if tf_scatter != state_ops.scatter_div: + vtypes.append(np.int32) + + for vtype in vtypes: for itype in (np.int32, np.int64): self._VariableRankTest(tf_scatter, vtype, itype, repeat_indices, updates_are_scalar) |