aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
blob: f9b9c77bbf7e2a8afdbfbd0929a68856b8aae51c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.scatter_nd."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


def _AsType(v, vtype):
  """Converts `v` to the type `vtype`.

  NumPy arrays go through `ndarray.astype`; any other value is handed to the
  `vtype` constructor directly.
  """
  if isinstance(v, np.ndarray):
    return v.astype(vtype)
  return vtype(v)


def _FlatInnerDims(tensor, ndims=2):
  """Reshapes `tensor` by collapsing its leading dimensions into one.

  The trailing `ndims - 1` dimensions are kept unchanged and every dimension
  before them is multiplied into a single leading dimension. For example, a
  tensor of shape (2, 3, 4) becomes (6, 4) for ndims=2, (2, 3, 4) for
  ndims=3 and (24,) for ndims=1. If `tensor` has at most `ndims - 1`
  dimensions, all of them are kept and a leading dimension of size 1 is
  prepended.

  Args:
    tensor: NumPy array (anything with `.shape` and `.reshape`).
    ndims: Number of dimensions of the result; expected to be >= 1.

  Returns:
    The reshaped array.
  """
  shape = list(tensor.shape)
  # Position where the preserved trailing dims start. Computing it explicitly
  # (instead of slicing with `-ndims + 1`) keeps ndims=1 correct: `-0` would
  # otherwise select the whole shape as "inner". Clamp at 0 so that an ndims
  # larger than the rank keeps every dimension.
  split = max(len(shape) - (ndims - 1), 0)
  outer, inner = shape[:split], shape[split:]
  return tensor.reshape(
      [functools.reduce(lambda x, y: x * y, outer, 1)] + inner)


def _FlatOuterDims(tensor, ndims=2):
  """Reshapes `tensor` by collapsing its trailing dimensions into one.

  The first `ndims - 1` dimensions are kept and all remaining ones are
  multiplied into a single final dimension (size 1 when none remain).

  Args:
    tensor: NumPy array (anything with `.shape` and `.reshape`).
    ndims: Number of leading dimensions kept, plus one.

  Returns:
    The reshaped array.
  """
  dims = list(tensor.shape)
  cut = ndims - 1
  head, tail = dims[:cut], dims[cut:]
  collapsed = 1
  for extent in tail:
    collapsed *= extent
  return tensor.reshape(head + [collapsed])


def _NumpyScatterNd(ref, indices, updates, op):
  """Reference NumPy implementation of the scatter_nd family of ops.

  The last dimension of `indices` is the index depth. Each index tuple selects
  one slice of `ref` (the dimensions after the index depth), and that slice is
  replaced by `op(current_slice, update_slice)`. Updates are applied one at a
  time in index order, so repeated indices compound.

  Args:
    ref: NumPy array to scatter into. It is reshaped rather than copied; when
      NumPy's reshape returns a view, `ref` itself is modified as well.
    indices: Integer NumPy array whose last dimension is the index depth.
    updates: NumPy array holding one slice per index tuple.
    op: Binary callable `op(old_slice, update_slice)` returning the new slice.

  Returns:
    The scattered array, reshaped back to `ref.shape`.
  """
  depth = indices.shape[-1]
  count = indices.size // depth
  # Number of elements addressed by a single index tuple.
  slice_size = functools.reduce(lambda a, b: a * b, ref.shape[depth:], 1)

  # One index tuple per row, one update slice per row.
  index_rows = indices.reshape((count, depth))
  update_rows = updates.reshape((count, slice_size))
  # Keep the indexed dims of ref and merge the rest into one slice dim.
  out = ref.reshape(tuple(ref.shape[:depth]) + (slice_size,))

  for index_row, update_row in zip(index_rows, update_rows):
    target = tuple(index_row)
    out[target] = op(out[target], update_row)
  return out.reshape(ref.shape)


def _NumpyUpdate(ref, indices, updates):
  """NumPy reference for scatter_nd_update: overwrite the addressed slices."""

  def _take_update(unused_old, update):
    return update

  return _NumpyScatterNd(ref, indices, updates, _take_update)


def _NumpyAdd(ref, indices, updates):
  """NumPy reference for scatter_nd_add: add updates into the slices."""

  def _add(old, update):
    return old + update

  return _NumpyScatterNd(ref, indices, updates, _add)


def _NumpySub(ref, indices, updates):
  """NumPy reference for scatter_nd_sub: subtract updates from the slices."""

  def _subtract(old, update):
    return old - update

  return _NumpyScatterNd(ref, indices, updates, _subtract)


def _NumpyMul(ref, indices, updates):
  """NumPy reference for scatter_nd_mul: multiply the slices by updates."""

  def _multiply(old, update):
    return old * update

  return _NumpyScatterNd(ref, indices, updates, _multiply)


def _NumpyDiv(ref, indices, updates):
  """NumPy reference for scatter_nd_div: divide the slices by updates.

  `/` is true division here because the module imports `division` from
  `__future__`.
  """

  def _divide(old, update):
    return old / update

  return _NumpyScatterNd(ref, indices, updates, _divide)


class StatefulScatterNdTest(test.TestCase):
  """Tests state_ops.scatter_nd_{update,add,sub} applied to variables."""

  def _VariableRankTest(self,
                        np_scatter,
                        tf_scatter,
                        vtype,
                        itype,
                        repeat_indices=False):
    """Checks `tf_scatter` against the NumPy reference `np_scatter`.

    For several (reference shape, index depth) combinations, random indices
    and updates are generated and applied both to a NumPy copy of the
    reference array and to a TF variable; the two results are compared.

    Args:
      np_scatter: NumPy reference called as `np_scatter(ref, indices,
        updates)`; its return value is the expected result.
      tf_scatter: TF op called as `tf_scatter(ref_var, indices, updates)`.
      vtype: NumPy dtype of the reference array and the updates.
      itype: NumPy dtype of the indices.
      repeat_indices: If True, about half of the index tuples are replaced by
        copies of the other half, so some slices are updated more than once.
    """
    np.random.seed(8)
    ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)]
    # Only the first entry (number of updates) and the last entry (index
    # depth) of each of these shapes are used below.
    indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)]
    with self.test_session(use_gpu=True):
      for ref_shape, indices_shape in zip(ref_shapes, indices_shapes):
        num_updates = indices_shape[0]
        ixdim = indices_shape[-1]

        # Every coordinate of the leading `ixdim` dims of ref, shuffled so the
        # first `num_updates` of them are distinct random index tuples.
        indexable_area_shape = tuple(ref_shape[:ixdim])
        all_indices = [
            list(coord) for coord in np.ndindex(indexable_area_shape)
        ]
        np.random.shuffle(all_indices)
        indices = np.array(all_indices[:num_updates])

        if num_updates > 1 and repeat_indices:
          # Keep the first half and refill the rest with copies drawn from it.
          half = num_updates // 2
          indices = indices[:half]
          for _ in range(num_updates - half):
            indices = np.append(
                indices, [indices[np.random.randint(half)]], axis=0)
          np.random.shuffle(indices)
        indices = _AsType(indices[:num_updates], itype)

        updates_shape = (num_updates,) + tuple(ref_shape[ixdim:])
        updates = _AsType(np.random.randn(*updates_shape), vtype)
        ref = _AsType(np.random.randn(*ref_shape), vtype)

        # Scatter via numpy. Use the returned array instead of relying on the
        # reference helper mutating its input through a reshape view.
        new = np_scatter(ref.copy(), indices, updates)
        # Scatter via tensorflow
        ref_var = variables.Variable(ref)
        ref_var.initializer.run()
        tf_scatter(ref_var, indices, updates).eval()

        # Compare
        self.assertAllClose(new, ref_var.eval())

  def _VariableRankTests(self, np_scatter, tf_scatter):
    """Runs _VariableRankTest for every tested value and index dtype."""
    for vtype in (np.int32,
                  np.float32, np.float64,
                  np.complex64, np.complex128):
      for itype in (np.int32, np.int64):
        self._VariableRankTest(np_scatter, tf_scatter, vtype, itype)

  def testSimple(self):
    """scatter_nd_update on a rank-1 Variable returns the updated value."""
    indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
    updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
    ref = variables.Variable([0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32)
    expected = np.array([0, 11, 0, 10, 9, 0, 0, 12])
    scatter = state_ops.scatter_nd_update(ref, indices, updates)
    init = variables.global_variables_initializer()

    with self.test_session(use_gpu=True) as sess:
      sess.run(init)
      result = sess.run(scatter)
      self.assertAllClose(result, expected)

  def testSimpleResource(self):
    """Like testSimple, but on a ResourceVariable, checking the variable."""
    indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
    updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
    ref = resource_variable_ops.ResourceVariable(
        [0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32)
    expected = np.array([0, 11, 0, 10, 9, 0, 0, 12])
    scatter = state_ops.scatter_nd_update(ref, indices, updates)
    init = variables.global_variables_initializer()

    with self.test_session(use_gpu=True) as sess:
      sess.run(init)
      sess.run(scatter)
      self.assertAllClose(ref.eval(), expected)

  def testSimple2(self):
    """Element updates into a rank-2 Variable (index depth 2)."""
    indices = constant_op.constant([[1, 0], [1, 1]], dtype=dtypes.int32)
    updates = constant_op.constant([11., 12.], dtype=dtypes.float32)
    ref = variables.Variable(
        [[0., 0.], [0., 0.], [0., 0.]], dtype=dtypes.float32)
    expected = np.array([[0., 0.], [11., 12.], [0., 0.]])
    scatter = state_ops.scatter_nd_update(ref, indices, updates)
    init = variables.global_variables_initializer()

    with self.test_session(use_gpu=True) as sess:
      sess.run(init)
      result = sess.run(scatter)
      self.assertAllClose(result, expected)

  def testSimple3(self):
    """Row (slice) update into a rank-2 Variable (index depth 1)."""
    indices = constant_op.constant([[1]], dtype=dtypes.int32)
    updates = constant_op.constant([[11., 12.]], dtype=dtypes.float32)
    ref = variables.Variable(
        [[0., 0.], [0., 0.], [0., 0.]], dtype=dtypes.float32)
    expected = np.array([[0., 0.], [11., 12.], [0., 0.]])
    scatter = state_ops.scatter_nd_update(ref, indices, updates)
    init = variables.global_variables_initializer()

    with self.test_session(use_gpu=True) as sess:
      sess.run(init)
      result = sess.run(scatter)
      self.assertAllClose(result, expected)

  def testVariableRankUpdate(self):
    self._VariableRankTests(_NumpyUpdate, state_ops.scatter_nd_update)

  def testVariableRankAdd(self):
    self._VariableRankTests(_NumpyAdd, state_ops.scatter_nd_add)

  def testVariableRankSub(self):
    self._VariableRankTests(_NumpySub, state_ops.scatter_nd_sub)

  # TODO(ebrevdo): Re-enable when we need ScatterNdMul.
  # def testVariableRankMul(self):
  #   self._VariableRankTests(_NumpyMul, state_ops.scatter_nd_mul)

  # TODO(ebrevdo): Re-enable when we need ScatterNdDiv.
  # def testVariableRankDiv(self):
  #   self._VariableRankTests(_NumpyDiv, state_ops.scatter_nd_div)

  def _ScatterRepeatIndicesTest(self, np_scatter, tf_scatter):
    """Runs _VariableRankTest with repeated indices for real/int dtypes."""
    for vtype in (np.int32, np.float32, np.float64):
      for itype in (np.int32, np.int64):
        self._VariableRankTest(
            np_scatter, tf_scatter, vtype, itype, repeat_indices=True)

  def testScatterRepeatIndices(self):
    """Tests scatter_nd_add and scatter_nd_sub using indices that repeat."""
    self._ScatterRepeatIndicesTest(_NumpyAdd, state_ops.scatter_nd_add)
    self._ScatterRepeatIndicesTest(_NumpySub, state_ops.scatter_nd_sub)
    # TODO(ebrevdo): Re-enable when we need ScatterNdMul and ScatterNdDiv.
    # self._ScatterRepeatIndicesTest(_NumpyMul, state_ops.scatter_nd_mul)
    # self._ScatterRepeatIndicesTest(_NumpyDiv, state_ops.scatter_nd_div)

  # TODO(simister): Re-enable once binary size increase due to
  # extra templating is back under control and this op is re-enabled
  # def testBooleanScatterUpdate(self):
  #   with self.test_session(use_gpu=False) as session:
  #     var = tf.Variable([True, False])
  #     update0 = tf.scatter_nd_update(var, [[1]], [True])
  #     update1 = tf.scatter_nd_update(
  #         var, tf.constant(
  #             [[0]], dtype=tf.int64), [False])
  #     var.initializer.run()
  #     session.run([update0, update1])
  #     self.assertAllEqual([False, True], var.eval())

  def testScatterOutOfRangeCpu(self):
    """On CPU, out-of-range indices fail with a descriptive op error."""
    # TODO(simister): Re-enable once binary size increase due to
    # scatter_nd ops is under control.
    #  tf.scatter_nd_mul, tf.scatter_nd_div,
    for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub,
               state_ops.scatter_nd_update):
      params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
      updates = np.array([-3, -4, -5]).astype(np.float32)
      with self.test_session(use_gpu=False):
        ref = variables.Variable(params)
        ref.initializer.run()

        # Indices all in range, no problem.
        indices = np.array([[2], [0], [5]])
        op(ref, indices, updates).eval()

        # Test some out of range errors.
        indices = np.array([[-1], [0], [5]])
        with self.assertRaisesOpError(
            r"Invalid indices: \[0,0\] = \[-1\] does not index into \[6\]"):
          op(ref, indices, updates).eval()

        indices = np.array([[2], [0], [6]])
        with self.assertRaisesOpError(
            r"Invalid indices: \[2,0\] = \[6\] does not index into \[6\]"):
          op(ref, indices, updates).eval()

  def testRank3ValidShape(self):
    """The static result shape equals the variable's shape."""
    indices = array_ops.zeros([2, 2, 2], dtypes.int32)
    updates = array_ops.zeros([2, 2, 2], dtypes.int32)
    shape = np.array([2, 2, 2])
    ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
    self.assertAllEqual(
        state_ops.scatter_nd_update(ref, indices,
                                    updates).get_shape().as_list(), shape)

  def testExtraIndicesDimensions(self):
    """Indices with more than one leading (batch) dimension are accepted."""
    indices = array_ops.zeros([1, 1, 2], dtypes.int32)
    updates = array_ops.zeros([1, 1], dtypes.int32)
    shape = np.array([2, 2])
    ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
    scatter_update = state_ops.scatter_nd_update(ref, indices, updates)
    self.assertAllEqual(scatter_update.get_shape().as_list(), shape)

    expected_result = np.zeros([2, 2], dtype=np.int32)
    with self.test_session():
      ref.initializer.run()
      self.assertAllEqual(expected_result, scatter_update.eval())

  def testRank3InvalidShape1(self):
    """Mismatched outer dims of indices/updates raise ValueError."""
    indices = array_ops.zeros([3, 2, 2], dtypes.int32)
    updates = array_ops.zeros([2, 2, 2], dtypes.int32)
    shape = np.array([2, 2, 2])
    ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
    with self.assertRaisesWithPredicateMatch(
        ValueError, "The outer \\d+ dimensions of indices\\.shape="):
      state_ops.scatter_nd_update(ref, indices, updates)

  def testRank3InvalidShape2(self):
    """Updates missing the variable's inner slice dims raise ValueError."""
    indices = array_ops.zeros([2, 2, 1], dtypes.int32)
    updates = array_ops.zeros([2, 2], dtypes.int32)
    shape = np.array([2, 2, 2])
    ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
    with self.assertRaisesWithPredicateMatch(
        ValueError, "The inner \\d+ dimensions of input\\.shape="):
      state_ops.scatter_nd_update(ref, indices, updates)

  def testConcurrentUpdates(self):
    """Many additions into the same element are all accumulated."""
    num_updates = 10000
    update_values = np.random.rand(num_updates)
    ref = variables.Variable(np.zeros([2, 2]), dtype=dtypes.float64)
    indices = constant_op.constant([[0, 1]] * num_updates, dtype=dtypes.int32)
    updates = constant_op.constant(update_values, dtype=dtypes.float64)

    expected_result = np.zeros([2, 2], dtype=np.float64)
    expected_result[0, 1] = np.sum(update_values)

    scatter = state_ops.scatter_nd_add(ref, indices, updates)
    init = variables.global_variables_initializer()

    with session.Session() as sess:
      sess.run(init)
      result = sess.run(scatter)
      # Use the TestCase assertion: a bare `assert` is stripped under -O and
      # reports no values on failure.
      self.assertAllClose(result, expected_result)

  # TODO(fpmc): Re-enable this test when gpu_pip test actually runs on a GPU.
  def _disabledTestScatterOutOfRangeGpu(self):
    if not test.IsBuiltWithCuda():
      return
    # TODO(simister): Re-enable once binary size increase due to
    # scatter_nd ops is under control.
    # tf.scatter_nd_mul, tf.scatter_nd_div,
    for op in (state_ops.scatter_nd_add, state_ops.scatter_nd_sub,
               state_ops.scatter_nd_update):
      params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
      updates = np.array([-3, -4, -5]).astype(np.float32)
      # With GPU, the code ignores indices that are out of range.
      # We don't test the implementation; just test there's no failures.
      with self.test_session(force_gpu=True):
        ref = variables.Variable(params)
        ref.initializer.run()

        # Indices need a trailing index-depth dimension (here 1), exactly as
        # in testScatterOutOfRangeCpu; flat [2, 0, 5] would be read as a
        # single depth-3 index into this rank-1 variable.
        # Indices all in range, no problem.
        indices = np.array([[2], [0], [5]])
        op(ref, indices, updates).eval()

        # Indices out of range should not fail.
        indices = np.array([[-1], [0], [5]])
        op(ref, indices, updates).eval()
        indices = np.array([[2], [0], [6]])
        op(ref, indices, updates).eval()


class ScatterNdTest(test.TestCase):
  """Tests for array_ops.scatter_nd.

  Almost every test goes through `self.scatter_nd`, so a subclass can re-run
  the suite against another op by overriding that method.
  """

  # Set to True by subclasses whose `scatter_nd` adds into `input_`; the
  # gradient tests then also check the gradient with respect to `input_`.
  non_aliasing_add_test = False

  def scatter_nd(self, indices, updates, shape, input_=None):
    """Returns array_ops.scatter_nd(indices, updates, shape).

    `input_` belongs to the shared signature but plain scatter_nd ignores it.
    """
    del input_
    return array_ops.scatter_nd(indices, updates, shape)

  @test_util.run_in_graph_and_eager_modes
  def testInvalidShape(self):
    # TODO(apassos) figure out how to unify these errors
    if context.executing_eagerly():
      expected_error = errors.InvalidArgumentError
    else:
      expected_error = ValueError
    with self.assertRaises(expected_error):
      array_ops.scatter_nd(indices=[0],  # this should be indices=[[0]]
                           updates=[0.0],
                           shape=[1])

  def _ScatterStrings(self, index_rows, words):
    """Scatters string `words` at `index_rows` into an 8-vector; returns it."""
    indices = constant_op.constant(index_rows, dtype=dtypes.int32)
    updates = constant_op.constant(words, dtype=dtypes.string)
    scatter = self.scatter_nd(indices, updates, shape=(8,))
    with self.test_session() as sess:
      return sess.run(scatter)

  def testString(self):
    result = self._ScatterStrings([[4], [3], [1], [7]],
                                  ["four", "three", "one", "seven"])
    self.assertAllEqual(
        np.array([b"", b"one", b"", b"three", b"four", b"", b"", b"seven"]),
        result)

    # One index receives the same string twice; both copies are joined.
    result = self._ScatterStrings([[4], [3], [3], [7]], ["a", "b", "b", "c"])
    self.assertAllEqual(
        np.array([b"", b"", b"", b"bb", b"a", b"", b"", b"c"]), result)

    # One index receives two different strings; either join order is OK.
    result = self._ScatterStrings([[4], [3], [3], [7]], ["a", "b", "c", "d"])
    either_order = (
        np.array([b"", b"", b"", b"bc", b"a", b"", b"", b"d"]),
        np.array([b"", b"", b"", b"cb", b"a", b"", b"", b"d"]))
    self.assertTrue(
        any(np.array_equal(result, candidate) for candidate in either_order))

  def testRank3ValidShape(self):
    zero_indices = array_ops.zeros([2, 2, 2], dtypes.int32)
    zero_updates = array_ops.zeros([2, 2, 2], dtypes.int32)
    out_shape = np.array([2, 2, 2])
    scatter = self.scatter_nd(zero_indices, zero_updates, out_shape)
    self.assertAllEqual(scatter.get_shape().as_list(), out_shape)

  def testExtraIndicesDimensions(self):
    zero_indices = array_ops.zeros([1, 1, 2], dtypes.int32)
    zero_updates = array_ops.zeros([1, 1], dtypes.int32)
    out_shape = np.array([2, 2])
    scatter = self.scatter_nd(zero_indices, zero_updates, out_shape)
    self.assertAllEqual(scatter.get_shape().as_list(), out_shape)
    with self.test_session():
      self.assertAllEqual(np.zeros([2, 2], dtype=np.int32), scatter.eval())

  def testUndefinedIndicesShape(self):
    # Graph construction must succeed with an unknown indices shape.
    self.scatter_nd(
        array_ops.placeholder(dtypes.int32, shape=None),
        array_ops.placeholder(dtypes.int32, shape=[2, 2, 2]),
        constant_op.constant([2, 2, 2], dtypes.int32))

  def testUndefinedUpdatesShape(self):
    # Graph construction must succeed with an unknown updates shape.
    self.scatter_nd(
        array_ops.placeholder(dtypes.int32, shape=[2, 2, 2]),
        array_ops.placeholder(dtypes.int32, shape=None),
        constant_op.constant([2, 2, 2], dtypes.int32))

  def testUndefinedOutputShape(self):
    # Graph construction must succeed when the output shape is a placeholder.
    self.scatter_nd(
        array_ops.placeholder(dtypes.int32, shape=[2, 2, 2]),
        array_ops.placeholder(dtypes.int32, shape=[2, 2, 2]),
        array_ops.placeholder(dtypes.int32, shape=[None]))

  def testEmptyOutputShape1(self):
    zero_indices = array_ops.zeros([2, 2, 2], dtypes.int32)
    zero_updates = array_ops.zeros([2, 2, 2], dtypes.int32)
    empty_shape = constant_op.constant([0, 3, 2], dtypes.int32)

    with self.assertRaisesWithPredicateMatch(
        ValueError, "Indices and updates specified for empty output shape"):
      self.scatter_nd(zero_indices, zero_updates, empty_shape)

  def testEmptyOutputShape2(self):
    indices_ph = array_ops.placeholder(dtypes.int32, shape=None)
    updates_ph = array_ops.placeholder(dtypes.int32, shape=None)
    empty_shape = constant_op.constant([0, 3, 2], dtypes.int32)
    feed = {
        indices_ph: np.zeros([2, 2, 2], dtype=np.int32),
        updates_ph: np.zeros([2, 2, 2], dtype=np.int32),
    }

    with self.test_session():
      with self.assertRaisesOpError(
          "Indices and updates specified for empty output"):
        self.scatter_nd(indices_ph, updates_ph,
                        empty_shape).eval(feed_dict=feed)

  def testEmptyOutputShape3(self):
    scatter = self.scatter_nd(
        array_ops.zeros([0], dtypes.int32),
        array_ops.zeros([0], dtypes.int32),
        constant_op.constant([0], dtypes.int32))

    with self.test_session():
      self.assertEqual(scatter.eval().size, 0)

  def testRank3InvalidShape1(self):
    zero_indices = array_ops.zeros([3, 2, 2], dtypes.int32)
    zero_updates = array_ops.zeros([2, 2, 2], dtypes.int32)
    out_shape = np.array([2, 2, 2])
    with self.assertRaisesWithPredicateMatch(
        ValueError, r"The outer \d+ dimensions of indices\.shape="):
      self.scatter_nd(zero_indices, zero_updates, out_shape)

  def testRank3InvalidShape2(self):
    zero_indices = array_ops.zeros([2, 2, 1], dtypes.int32)
    zero_updates = array_ops.zeros([2, 2], dtypes.int32)
    out_shape = np.array([2, 2, 2])
    with self.assertRaisesWithPredicateMatch(
        ValueError, r"The inner \d+ dimensions of (input|output)\.shape="):
      self.scatter_nd(zero_indices, zero_updates, out_shape)

  def _CheckGradients(self, indices, updates, shape, grad_vals,
                      expected_updates_grad, expected_input_grad):
    """Builds a float64 scatter and checks its gradients.

    Args:
      indices: Nested int list, built as an int32 constant.
      updates: Nested number list, built as a float64 constant.
      shape: Output shape list, built as an int32 constant.
      grad_vals: Upstream gradient for the scatter output (float64).
      expected_updates_grad: Expected gradient w.r.t. `updates`.
      expected_input_grad: Expected gradient w.r.t. the zero `input_`; only
        checked when `non_aliasing_add_test` is set.
    """
    indices_t = constant_op.constant(indices, dtype=dtypes.int32)
    updates_t = constant_op.constant(updates, dtype=dtypes.float64)
    shape_t = constant_op.constant(shape, dtype=dtypes.int32)
    input_t = array_ops.zeros(shape_t, dtype=dtypes.float64)
    outputs = self.scatter_nd(indices_t, updates_t, shape_t, input_t)

    grad_t = constant_op.constant(grad_vals, dtype=dtypes.float64)
    updates_grad, input_grad = gradients_impl.gradients(
        [outputs], [updates_t, input_t], [grad_t])
    want_updates = np.array(expected_updates_grad, dtype=np.float64)
    want_input = np.array(expected_input_grad, dtype=np.float64)
    with self.test_session():
      self.assertAllEqual(want_updates, updates_grad.eval())
      if self.non_aliasing_add_test:
        self.assertAllEqual(want_input, input_grad.eval())

  def testGradientsRank2ElementUpdate(self):
    self._CheckGradients(
        indices=[[0, 0], [1, 1]],
        updates=[1, 4],
        shape=[2, 2],
        grad_vals=[[1, 2], [3, 4]],
        expected_updates_grad=[1, 4],
        expected_input_grad=[[1, 2], [3, 4]])

  def testGradientsRank2SliceUpdate(self):
    self._CheckGradients(
        indices=[[1], [0]],
        updates=[[3, 4], [1, 2]],
        shape=[2, 2],
        grad_vals=[[3, 4], [1, 2]],
        expected_updates_grad=[[1, 2], [3, 4]],
        expected_input_grad=[[3, 4], [1, 2]])

  def testGradientsRank3SliceUpdate(self):
    self._CheckGradients(
        indices=[[[0, 1], [1, 0]], [[0, 0], [1, 1]]],
        updates=[[[5, 7], [2, 4]], [[1, 3], [6, 8]]],
        shape=[2, 2, 2],
        grad_vals=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
        expected_updates_grad=[[[3, 4], [5, 6]], [[1, 2], [7, 8]]],
        expected_input_grad=[[[1, 2], [3, 4]], [[5, 6], [7, 8]]])

  def testGradientsRank7SliceUpdate(self):
    self._CheckGradients(
        indices=[[[
            [[[[0, 0, 0, 0, 0, 1], [0, 0, 1, 0, 0, 0]]]],
            [[[[0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 1]]]]
        ]]],
        updates=[[[
            [[[[5, 6], [2, 4]]]],
            [[[[1, 3], [6, 8]]]]
        ]]],
        shape=[1, 1, 2, 1, 1, 2, 2],
        grad_vals=[[[
            [[[[1, 2], [3, 4]]]],
            [[[[5, 6], [7, 8]]]]
        ]]],
        expected_updates_grad=[[[
            [[[[3, 4], [5, 6]]]],
            [[[[1, 2], [7, 8]]]]
        ]]],
        expected_input_grad=[[[
            [[[[1, 2], [3, 4]]]],
            [[[[5, 6], [7, 8]]]]
        ]]])

  def testScatterNdRepatedIndicesAdd(self):
    # Every update targets index 0, so the result must be their sum.
    count = 100000
    all_zero_indices = array_ops.zeros([count, 1], dtypes.int32)
    values = np.random.randn(count)
    with self.test_session():
      summed = self.scatter_nd(all_zero_indices, values, [1]).eval()
    self.assertAllClose([np.sum(values)], summed)

  def _SmokeScatter(self, indices_shape, values_shape, output_shape):
    """Evaluates a scatter of all-zero indices/values; only checks it runs."""
    with self.test_session():
      indices = array_ops.zeros(indices_shape, dtype=dtypes.int32)
      values = array_ops.zeros(values_shape)
      self.scatter_nd(indices, values, output_shape).eval()

  def testSmokeScatterNdBatch2DSliceDim2(self):
    self._SmokeScatter([3, 5, 2], [3, 5, 7], [4, 6, 7])

  def testSmokeScatterNdBatch1DSliceDim2(self):
    self._SmokeScatter([0, 2], [0, 7], [4, 6, 7])

  def testSmokeScatterNdBatch1DSliceDim3ShapeRank7(self):
    self._SmokeScatter([1, 3], [1, 6, 7, 8, 9], [3, 4, 5, 6, 7, 8, 9])

  def testSmokeScatterNdBatch2DSliceDim3ShapeRank7(self):
    self._SmokeScatter([1, 2, 3], [1, 2, 6, 7, 8, 9], [3, 4, 5, 6, 7, 8, 9])


class ScatterNdNonAliasingAddTest(ScatterNdTest):
  """Re-runs the ScatterNdTest suite against scatter_nd_non_aliasing_add."""

  # Also check gradients with respect to `input_` in the gradient tests.
  non_aliasing_add_test = True

  def scatter_nd(self, indices, updates, shape, input_=None):
    """Adds `updates` at `indices` into `input_`.

    When `input_` is omitted, a zero tensor of `shape` with the dtype of
    `updates` is used instead.
    """
    if input_ is None:
      input_ = array_ops.zeros(shape, dtype=updates.dtype)
    return array_ops.scatter_nd_non_aliasing_add(input_, indices, updates)

  def testString(self):
    """Skipped: string inputs are not supported here yet."""


# Script entry point: run this module's test cases through
# tensorflow.python.platform.test.
if __name__ == "__main__":
  test.main()