diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/scatter_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/scatter_ops_test.py | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py index af541a96c1..c4c248186c 100644 --- a/tensorflow/python/kernel_tests/scatter_ops_test.py +++ b/tensorflow/python/kernel_tests/scatter_ops_test.py @@ -63,6 +63,17 @@ class ScatterTest(tf.test.TestCase): ref[indices] -= updates self._VariableRankTest(sub, tf.scatter_sub) + def testBooleanScatterUpdate(self): + with self.test_session() as session: + var = tf.Variable([True, False]) + update0 = tf.scatter_update(var, 1, True) + update1 = tf.scatter_update(var, tf.constant(0, dtype=tf.int64), False) + var.initializer.run() + + session.run([update0, update1]) + + self.assertAllEqual([False, True], var.eval()) + if __name__ == "__main__": tf.test.main() |