aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2016-11-07 16:01:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-08 16:12:19 -0800
commitfd05b5ebc56316eb6ac9fcb74234979fee2fc5f9 (patch)
tree051f2d5145673d8bbebe0646434860c888991815 /tensorflow/python/kernel_tests/scatter_nd_ops_test.py
parentaac685b7209b03ffd356ea6860366467b335d402 (diff)
Changes to scatter_nd ops
* Rewrite CPU impl to be single-threaded and use vectorization; avoids race conditions. Removes use of the generator. * Remove scatter_nd_mul and scatter_nd_div to reduce binary size until we figure out a better way to reduce the templating pain * Modify scatter_nd to add for repeated indices as opposed to update (this is the appropriate gradient for gather_nd, for example) * Clean up docstrings. Change: 138452341
Diffstat (limited to 'tensorflow/python/kernel_tests/scatter_nd_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py64
1 files changed, 40 insertions, 24 deletions
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index 0461758d27..3d2ac798cd 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -78,7 +78,7 @@ def _NumpyDiv(ref, indices, updates):
return _NumpyScatterNd(ref, indices, updates, lambda p, u: p / u)
-class ScatterTest(tf.test.TestCase):
+class ScatterNdTest(tf.test.TestCase):
def _VariableRankTest(self,
np_scatter,
@@ -145,11 +145,13 @@ class ScatterTest(tf.test.TestCase):
def testVariableRankSub(self):
self._VariableRankTests(_NumpySub, tf.scatter_nd_sub)
- def testVariableRankMul(self):
- self._VariableRankTests(_NumpyMul, tf.scatter_nd_mul)
+ # TODO(simister): Re-enable once binary size increase due to
+ # scatter_nd ops is under control.
+ # def testVariableRankMul(self):
+ # self._VariableRankTests(_NumpyMul, tf.scatter_nd_mul)
- def testVariableRankDiv(self):
- self._VariableRankTests(_NumpyDiv, tf.scatter_nd_div)
+ # def testVariableRankDiv(self):
+ # self._VariableRankTests(_NumpyDiv, tf.scatter_nd_div)
def _ScatterRepeatIndicesTest(self, np_scatter, tf_scatter):
for vtype in (np.float32, np.float64):
@@ -167,25 +169,29 @@ class ScatterTest(tf.test.TestCase):
"""This tests scatter_add using indices that repeat."""
self._ScatterRepeatIndicesTest(_NumpyAdd, tf.scatter_nd_add)
self._ScatterRepeatIndicesTest(_NumpySub, tf.scatter_nd_sub)
- self._ScatterRepeatIndicesTest(_NumpyMul, tf.scatter_nd_mul)
- self._ScatterRepeatIndicesTest(_NumpyDiv, tf.scatter_nd_div)
-
- def testBooleanScatterUpdate(self):
- with self.test_session(use_gpu=False) as session:
- var = tf.Variable([True, False])
- update0 = tf.scatter_nd_update(var, [[1]], [True])
- update1 = tf.scatter_nd_update(
- var, tf.constant(
- [[0]], dtype=tf.int64), [False])
- var.initializer.run()
-
- session.run([update0, update1])
-
- self.assertAllEqual([False, True], var.eval())
+ # TODO(simister): Re-enable once binary size increase due to
+ # extra templating is back under control.
+ # self._ScatterRepeatIndicesTest(_NumpyMul, tf.scatter_nd_mul)
+ # self._ScatterRepeatIndicesTest(_NumpyDiv, tf.scatter_nd_div)
+
+ # TODO(simister): Re-enable once binary size increase due to
+ # extra templating is back under control and this op is re-enabled
+ # def testBooleanScatterUpdate(self):
+ # with self.test_session(use_gpu=False) as session:
+ # var = tf.Variable([True, False])
+ # update0 = tf.scatter_nd_update(var, [[1]], [True])
+ # update1 = tf.scatter_nd_update(
+ # var, tf.constant(
+ # [[0]], dtype=tf.int64), [False])
+ # var.initializer.run()
+ # session.run([update0, update1])
+ # self.assertAllEqual([False, True], var.eval())
def testScatterOutOfRangeCpu(self):
- for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_mul,
- tf.scatter_nd_div, tf.scatter_nd_update):
+ # TODO(simister): Re-enable once binary size increase due to
+ # scatter_nd ops is under control.
+ # tf.scatter_nd_mul, tf.scatter_nd_div,
+ for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_update):
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
updates = np.array([-3, -4, -5]).astype(np.float32)
with self.test_session(use_gpu=False):
@@ -355,8 +361,10 @@ class ScatterTest(tf.test.TestCase):
def _disabledTestScatterOutOfRangeGpu(self):
if not tf.test.IsBuiltWithCuda():
return
- for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_mul,
- tf.scatter_nd_div, tf.scatter_nd_update):
+ # TODO(simister): Re-enable once binary size increase due to
+ # scatter_nd ops is under control.
+ # tf.scatter_nd_mul, tf.scatter_nd_div,
+ for op in (tf.scatter_nd_add, tf.scatter_nd_sub, tf.scatter_nd_update):
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
updates = np.array([-3, -4, -5]).astype(np.float32)
# With GPU, the code ignores indices that are out of range.
@@ -375,6 +383,14 @@ class ScatterTest(tf.test.TestCase):
indices = np.array([2, 0, 6])
op(ref, indices, updates).eval()
+ def testScatterNdRepatedIndicesAdd(self):
+ indices = tf.zeros([100000, 1], tf.int32)
+ values = np.random.randn(100000)
+ shape = [1]
+ with self.test_session():
+ val = tf.scatter_nd(indices, values, shape).eval()
+ self.assertAllClose([np.sum(values)], val)
+
if __name__ == "__main__":
tf.test.main()