aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/scatter_nd_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py27
1 files changed, 25 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index f9b9c77bbf..f2f3023469 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -268,12 +268,12 @@ class StatefulScatterNdTest(test.TestCase):
# Test some out of range errors.
indices = np.array([[-1], [0], [5]])
with self.assertRaisesOpError(
- r"Invalid indices: \[0,0\] = \[-1\] does not index into \[6\]"):
+ r"indices\[0\] = \[-1\] does not index into shape \[6\]"):
op(ref, indices, updates).eval()
indices = np.array([[2], [0], [6]])
with self.assertRaisesOpError(
- r"Invalid indices: \[2,0\] = \[6\] does not index into \[6\]"):
+ r"indices\[2\] = \[6\] does not index into shape \[6\]"):
op(ref, indices, updates).eval()
def testRank3ValidShape(self):
@@ -370,6 +370,29 @@ class ScatterNdTest(test.TestCase):
return array_ops.scatter_nd(indices, updates, shape)
@test_util.run_in_graph_and_eager_modes
+ def testBool(self):
+ indices = constant_op.constant(
+ [[4], [3], [1], [7]], dtype=dtypes.int32)
+ updates = constant_op.constant(
+ [False, True, False, True], dtype=dtypes.bool)
+ expected = np.array(
+ [False, False, False, True, False, False, False, True])
+ scatter = self.scatter_nd(indices, updates, shape=(8,))
+ result = self.evaluate(scatter)
+ self.assertAllEqual(expected, result)
+
+ # Same indice is updated twice by same value.
+ indices = constant_op.constant(
+ [[4], [3], [3], [7]], dtype=dtypes.int32)
+ updates = constant_op.constant(
+ [False, True, True, True], dtype=dtypes.bool)
+ expected = np.array([
+ False, False, False, True, False, False, False, True])
+ scatter = self.scatter_nd(indices, updates, shape=(8,))
+ result = self.evaluate(scatter)
+ self.assertAllEqual(expected, result)
+
+ @test_util.run_in_graph_and_eager_modes
def testInvalidShape(self):
# TODO(apassos) figure out how to unify these errors
with self.assertRaises(errors.InvalidArgumentError