From f41959ccb2d9d4c722fe8fc3351401d53bcf4900 Mon Sep 17 00:00:00 2001 From: Manjunath Kudlur Date: Fri, 6 Nov 2015 16:27:58 -0800 Subject: TensorFlow: Initial commit of TensorFlow library. TensorFlow is an open source software library for numerical computation using data flow graphs. Base CL: 107276108 --- tensorflow/python/kernel_tests/scatter_ops_test.py | 49 ++++++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 tensorflow/python/kernel_tests/scatter_ops_test.py (limited to 'tensorflow/python/kernel_tests/scatter_ops_test.py') diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py new file mode 100644 index 0000000000..dd645819a3 --- /dev/null +++ b/tensorflow/python/kernel_tests/scatter_ops_test.py @@ -0,0 +1,49 @@ +"""Tests for tensorflow.ops.tf.scatter.""" +import tensorflow.python.platform + +import numpy as np +import tensorflow as tf + + +class ScatterTest(tf.test.TestCase): + + def _VariableRankTest(self, np_scatter, tf_scatter): + np.random.seed(8) + with self.test_session(): + for indices_shape in (), (2,), (2, 3), (2, 3, 4): + for extra_shape in (), (5,), (5, 6): + # Generate random indices with no duplicates for easy numpy comparison + size = np.prod(indices_shape, dtype=np.int32) + indices = np.arange(2 * size) + np.random.shuffle(indices) + indices = indices[:size].reshape(indices_shape) + updates = np.random.randn(*(indices_shape + extra_shape)) + old = np.random.randn(*((2 * size,) + extra_shape)) + # Scatter via numpy + new = old.copy() + np_scatter(new, indices, updates) + # Scatter via tensorflow + ref = tf.Variable(old) + ref.initializer.run() + tf_scatter(ref, indices, updates).eval() + # Compare + self.assertAllClose(ref.eval(), new) + + def testVariableRankUpdate(self): + def update(ref, indices, updates): + ref[indices] = updates + self._VariableRankTest(update, tf.scatter_update) + + def testVariableRankAdd(self): + def add(ref, indices, updates): + ref[indices] += updates + self._VariableRankTest(add, tf.scatter_add) + + def testVariableRankSub(self): + def sub(ref, indices, updates): + ref[indices] -= updates + self._VariableRankTest(sub, tf.scatter_sub) + + +if __name__ == "__main__": + tf.test.main() -- cgit v1.2.3