aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/scatter_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/scatter_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/scatter_ops_test.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
index af541a96c1..c4c248186c 100644
--- a/tensorflow/python/kernel_tests/scatter_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -63,6 +63,17 @@ class ScatterTest(tf.test.TestCase):
ref[indices] -= updates
self._VariableRankTest(sub, tf.scatter_sub)
+ def testBooleanScatterUpdate(self):
+ with self.test_session() as session:
+ var = tf.Variable([True, False])
+ update0 = tf.scatter_update(var, 1, True)
+ update1 = tf.scatter_update(var, tf.constant(0, dtype=tf.int64), False)
+ var.initializer.run()
+
+ session.run([update0, update1])
+
+ self.assertAllEqual([False, True], var.eval())
+
if __name__ == "__main__":
tf.test.main()