# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Version 2 of class Optimizer."""
# pylint: disable=g-bad-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import checkpointable
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.training import slot_creator
from tensorflow.python.util import nest


class _OptimizableVariable(object):
  """Interface for abstracting over variables in the optimizers."""

  @abc.abstractmethod
  def target(self):
    """Returns the optimization target for this variable."""
    raise NotImplementedError("Calling an abstract method.")

  @abc.abstractmethod
  def update_op(self, optimizer, g, *args):
    """Returns the update ops for updating the variable."""
    raise NotImplementedError("Calling an abstract method.")


class _RefVariableProcessor(_OptimizableVariable):
  """Processor for Variable."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v._ref()  # pylint: disable=protected-access

  def update_op(self, optimizer, g, *args):
    if isinstance(g, ops.Tensor):
      update_op = optimizer._apply_dense(g, self._v, *args)  # pylint: disable=protected-access
      if self._v.constraint is not None:
        with ops.control_dependencies([update_op]):
          return self._v.assign(self._v.constraint(self._v))
      else:
        return update_op
    else:
      assert isinstance(g, ops.IndexedSlices), (
          "Gradient %s is neither a Tensor nor IndexedSlices." % (g,))
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      # pylint: disable=protected-access
      return optimizer._apply_sparse_duplicate_indices(g, self._v, *args)


class _DenseReadResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g, *args):
    # pylint: disable=protected-access
    update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0], *args)
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _DenseResourceVariableProcessor(_OptimizableVariable):
  """Processor for dense ResourceVariables."""

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g, *args):
    # pylint: disable=protected-access
    if isinstance(g, ops.IndexedSlices):
      if self._v.constraint is not None:
        raise RuntimeError(
            "Cannot use a constraint function on a sparse variable.")
      return optimizer._resource_apply_sparse_duplicate_indices(
          g.values, self._v, g.indices, *args)
    update_op = optimizer._resource_apply_dense(g, self._v, *args)
    if self._v.constraint is not None:
      with ops.control_dependencies([update_op]):
        return self._v.assign(self._v.constraint(self._v))
    else:
      return update_op


class _TensorProcessor(_OptimizableVariable):
  """Processor for ordinary Tensors.

  Even though a Tensor can't really be updated, sometimes it is useful to
  compute the gradients with respect to a Tensor using the optimizer. Updating
  the Tensor is, of course, unsupported.
  """

  def __init__(self, v):
    self._v = v

  def target(self):
    return self._v

  def update_op(self, optimizer, g, *args):
    raise NotImplementedError("Trying to update a Tensor ", self._v)


def _get_processor(v):
  """The processor of v."""
  if context.executing_eagerly():
    if isinstance(v, ops.Tensor):
      return _TensorProcessor(v)
    else:
      return _DenseResourceVariableProcessor(v)
  if v.op.type == "VarHandleOp":
    return _DenseResourceVariableProcessor(v)
  if isinstance(v, variables.Variable):
    return _RefVariableProcessor(v)
  if isinstance(v, ops.Tensor):
    return _TensorProcessor(v)
  raise NotImplementedError("Trying to optimize unsupported type ", v)


def _var_key_v2(var):
  """Key for representing a primary variable, for looking up slots."""
  # pylint: disable=protected-access
  if hasattr(var, "_mirrored_container"):
    mirrored_container = var._mirrored_container()
    assert mirrored_container is not None
    if context.executing_eagerly():
      return mirrored_container._unique_id
    return mirrored_container._shared_name
  if context.executing_eagerly():
    return var._unique_id
  return var.op.name


def _resolve(value, name):
  if callable(value):
    value = value()
  return ops.convert_to_tensor(value, name=name)


def _is_dynamic(value):
  """Returns true if __init__ arg `value` should be re-evaluated each step."""
  if callable(value): return True
  # Don't need to do anything special in graph mode, since dynamic values
  # will propagate correctly automatically.
  # TODO(josh11b): Add per-device caching across steps using variables for
  # truly static values once we add distributed support.
  if context.executing_eagerly() and isinstance(
      value, resource_variable_ops.ResourceVariable):
    return True
  return False


class _OptimizerV2State(object):
  """Holds per-graph and per-step optimizer state.

  Use _init_with_static_hyper() to create the state for a graph, and then
  _copy_with_dynamic_hyper() to convert that to state for a particular step.
  The difference between the two is that the former only has hyper
  parameter values that are static and the latter also has values that
  can change every step (according to _is_dynamic()).
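
  A minimal sketch of the intended flow (internal API; `hyper` is the
  `{name: (is_dynamic, value)}` dict kept by the optimizer):

  ```python
  state = _OptimizerV2State("MyOp")
  state._init_with_static_hyper(hyper)  # Once per graph.
  step_state = state._copy_with_dynamic_hyper(  # Once per step.
      hyper, distribution, non_slot_devices)
  ```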
  """

  def __init__(self, op_name):
    self._op_name = op_name

  def _init_with_static_hyper(self, hyper):
    """Initialize a fresh state object from hyper dict."""
    # self._hyper contains a dict from name to a dict with the Tensor values.
    # This dict starts with a single item with key "None" with the hyper
    # parameter value converted to a Tensor. Other items have dtype keys
    # with that Tensor cast to that dtype.
    self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
                   for name, (dynamic, value) in hyper.items() if not dynamic}
    self._slots = {}
    self._non_slot_dict = {}
    # Extra state to help Optimizers implement Checkpointable. Holds information
    # about variables which will be restored as soon as they're created.
    self._deferred_dependencies = {}  # Non-slot variables
    self._deferred_slot_restorations = {}  # Slot variables

  def _copy_with_dynamic_hyper(self, hyper, distribution, non_slot_devices):
    """Create a new state object for a particular step."""
    ret = _OptimizerV2State(self._op_name)
    # pylint: disable=protected-access
    ret._slots = self._slots
    ret._non_slot_dict = self._non_slot_dict
    ret._deferred_dependencies = self._deferred_dependencies
    ret._deferred_slot_restorations = self._deferred_slot_restorations
    ret._hyper = {name: {None: _resolve(value, name)}
                  for name, (dynamic, value) in hyper.items() if dynamic}
    ret._hyper.update(self._hyper)
    ret._non_slot_devices = non_slot_devices
    ret._distribution = distribution
    return ret

  def _variables(self):
    """Returns a list of all variables held by self."""
    optimizer_variables = list(self._non_slot_dict.values())
    for variable_dict in self._slots.values():
      for slot_for_variable in variable_dict.values():
        optimizer_variables.append(slot_for_variable)
    # Sort variables by name so that the return is deterministic.
    return sorted(optimizer_variables, key=lambda v: v.name)

  def _slot_dict(self, slot_name):
    """Returns a dict for caching slots created under the given name.

    Args:
      slot_name: Name for the slot.

    Returns:
      A dict that maps primary `Variable` objects to the slot created
      for that variable, under the given slot name.
    """
    named_slots = self._slots.get(slot_name, None)
    if named_slots is None:
      named_slots = {}
      self._slots[slot_name] = named_slots
    return named_slots

  def create_slot(self, var, val, slot_name, optional_op_name=None):
    """Find or create a slot for a variable.

    Args:
      var: A `Variable` object.
      val: A `Tensor`.  The initial value of the slot.
      slot_name: Name for the slot.
      optional_op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    var_key = _var_key_v2(var)
    if var_key not in named_slots:
      new_slot_variable = slot_creator.create_slot(
          var, val, optional_op_name or self._op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[var_key] = new_slot_variable
    return named_slots[var_key]

  def create_slot_with_initializer(self, var, initializer, shape, dtype,
                                   slot_name, optional_op_name=None):
    """Find or create a slot for a variable, using an Initializer.

    Args:
      var: A `Variable` object.
      initializer: An `Initializer`.  The initial value of the slot.
      shape: Shape of the initial value of the slot.
      dtype: Type of the value of the slot.
      slot_name: Name for the slot.
      optional_op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    var_key = _var_key_v2(var)
    if var_key not in named_slots:
      new_slot_variable = slot_creator.create_slot_with_initializer(
          var, initializer, shape, dtype, optional_op_name or self._op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[var_key] = new_slot_variable
    return named_slots[var_key]

  def zeros_slot(self, var, slot_name, optional_op_name=None):
    """Find or create a slot initialized with 0.0.

    Args:
      var: A `Variable` object.
      slot_name: Name for the slot.
      optional_op_name: Name to use when scoping the Variable that
        needs to be created for the slot.

    Returns:
      A `Variable` object.
    """
    named_slots = self._slot_dict(slot_name)
    var_key = _var_key_v2(var)
    if var_key not in named_slots:
      new_slot_variable = slot_creator.create_zeros_slot(
          var, optional_op_name or self._op_name)
      self._restore_slot_variable(
          slot_name=slot_name, variable=var,
          slot_variable=new_slot_variable)
      named_slots[var_key] = new_slot_variable
    return named_slots[var_key]

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable,
      optional_op_name=None):
    """Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored. When executing eagerly, we create the slot variable with a
    restoring initializer.

    No new variables are created when graph building. Instead,
    _restore_slot_variable catches these after normal creation and adds restore
    ops to the graph. This method is nonetheless important when graph building
    for the case when a slot variable has already been created but `variable`
    has just been added to a dependency graph (causing us to realize that the
    slot variable needs to be restored).

    Args:
      slot_variable_position: A `checkpointable._CheckpointPosition` object
        indicating the slot variable `Checkpointable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
      optional_op_name: Name to use when scoping the Variable that
        needs to be created for the slot.
    """
    slot_variable = self.get_slot(var=variable, name=slot_name)
    if (slot_variable is None and context.executing_eagerly() and
        slot_variable_position.is_simple_variable()):
      initializer = checkpointable.CheckpointInitialValue(
          checkpoint_position=slot_variable_position)
      slot_variable = self.create_slot(
          var=variable,
          val=initializer,
          slot_name=slot_name,
          optional_op_name=optional_op_name)
      # Optimizers do not have unconditional dependencies on their slot
      # variables (nor do any other objects). They are only saved if the
      # variables they were created for are also saved.
    if slot_variable is not None:
      # If we've either made this slot variable, or if we've pulled out an
      # existing slot variable, we should restore it.
      slot_variable_position.restore(slot_variable)
    else:
      # We didn't make the slot variable. Defer restoring until it gets created
      # normally. We keep a list rather than the one with the highest restore
      # UID in case slot variables have their own dependencies, in which case
      # those could differ between restores.
      variable_key = _var_key_v2(variable)
      self._deferred_slot_restorations.setdefault(
          slot_name, {}).setdefault(variable_key, []).append(
              slot_variable_position)

  def get_slot(self, var, name):
    """Return a slot named `name` created for `var` by the Optimizer.

    Some `Optimizer` subclasses use additional variables.  For example
    `Momentum` and `Adagrad` use variables to accumulate updates.  This method
    gives access to these `Variable` objects if for some reason you need them.

    Use `get_slot_names()` to get the list of slot names created by the
    `Optimizer`.

    Args:
      var: A variable passed to `minimize()` or `apply_gradients()`.
      name: A string.

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    named_slots = self._slots.get(name, None)
    if not named_slots:
      return None
    return named_slots.get(_var_key_v2(var), None)

  def get_slot_names(self):
    """Return a list of the names of slots created by the `Optimizer`.

    See `get_slot()`.

    Returns:
      A list of strings.
    """
    return sorted(self._slots.keys())

  def create_non_slot(self, initial_value, name, colocate_with=None):
    """Add an extra variable, not associated with a slot."""
    v = self._non_slot_dict.get(name, None)
    if v is None:
      if colocate_with is None: colocate_with = self._non_slot_devices
      with self._distribution.colocate_vars_with(colocate_with):
        # TODO(josh11b): Use get_variable() except for the legacy Adam use case.
        v = variable_scope.variable(initial_value, name=name, trainable=False)
      self._non_slot_dict[name] = v
      deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
      for checkpoint_position in sorted(
          deferred_dependencies_list,
          key=lambda restore: restore.checkpoint.restore_uid,
          reverse=True):
        checkpoint_position.restore(v)
    return v

  def _restore_slot_variable(self, slot_name, variable, slot_variable):
    """Restore a newly created slot variable's value."""
    variable_key = _var_key_v2(variable)
    deferred_restorations = self._deferred_slot_restorations.get(
        slot_name, {}).pop(variable_key, [])
    # Iterate over restores, highest restore UID first to minimize the number
    # of assignments.
    deferred_restorations.sort(key=lambda position: position.restore_uid,
                               reverse=True)
    for checkpoint_position in deferred_restorations:
      checkpoint_position.restore(slot_variable)

  def get_non_slot(self, name):
    """Returns the non-slot variable identified by `name`."""
    return self._non_slot_dict.get(name, None)

  def get_hyper(self, name, dtype=None):
    """Returns the `name` hyper parameter, optionally cast to `dtype`."""
    dtype_dict = self._hyper[name]
    # Do we have the value cast to dtype already cached? This should always
    # succeed when dtype is None.
    if dtype in dtype_dict:
      return dtype_dict[dtype]
    # Not cached, cast to dtype and save the result in the cache.
    result = math_ops.cast(dtype_dict[None], dtype)
    dtype_dict[dtype] = result
    return result


class OptimizerV2(optimizer_v1.Optimizer):
  """Updated base class for optimizers.

  This class defines the API to add Ops to train a model.  You never use this
  class directly, but instead instantiate one of its subclasses such as
  `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.

  ### Usage

  ```python
  # Create an optimizer with the desired parameters.
  opt = GradientDescentOptimizer(learning_rate=0.1)
  # Add Ops to the graph to minimize a cost by updating a list of variables.
  # "cost" is a Tensor, and the list of variables contains tf.Variable
  # objects.
  opt_op = opt.minimize(cost, var_list=<list of variables>)
  ```

  In the training program you will just have to run the returned Op.

  ```python
  # Execute opt_op to do one step of training:
  opt_op.run()
  ```

  ### Processing gradients before applying them.

  Calling `minimize()` takes care of both computing the gradients and
  applying them to the variables.  If you want to process the gradients
  before applying them you can instead use the optimizer in three steps:

  1.  Compute the gradients with `compute_gradients()`.
  2.  Process the gradients as you wish.
  3.  Apply the processed gradients with `apply_gradients()`.

  Example:

  ```python
  # Create an optimizer.
  opt = GradientDescentOptimizer(learning_rate=0.1)

  # Compute the gradients for a list of variables.
  grads_and_vars = opt.compute_gradients(loss, <list of variables>)

  # grads_and_vars is a list of tuples (gradient, variable).  Do whatever you
  # need to the 'gradient' part, for example cap them, etc.
  capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]

  # Ask the optimizer to apply the capped gradients.
  opt.apply_gradients(capped_grads_and_vars)
  ```

  ### Gating Gradients

  Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
  argument that controls the degree of parallelism during the application of
  the gradients.

  The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.

  <b>`GATE_NONE`</b>: Compute and apply gradients in parallel.  This provides
  the maximum parallelism in execution, at the cost of some non-reproducibility
  in the results.  For example, the two gradients of `matmul` depend on the
  input values: with `GATE_NONE` one of the gradients could be applied to one
  of the inputs _before_ the other gradient is computed, resulting in
  non-reproducible results.

  <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
  they are used.  This prevents race conditions for Ops that generate gradients
  for multiple inputs where the gradients depend on the inputs.

  <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
  before any one of them is used.  This provides the least parallelism but can
  be useful if you want to process all gradients before applying any of them.
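
  For example, to favor reproducibility over parallelism (a sketch; `opt` and
  `cost` are as in the usage example above):

  ```python
  opt_op = opt.minimize(cost, gate_gradients=opt.GATE_GRAPH)
  ```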

  ### Slots

  Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`,
  allocate and manage additional variables associated with the variables to
  train.  These are called <i>Slots</i>.  Slots have names and you can ask the
  optimizer for the names of the slots that it uses.  Once you have a slot name
  you can ask the optimizer for the variable it created to hold the slot value.

  This can be useful if you want to log or debug a training algorithm, report
  stats about the slots, etc.
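
  For example (a sketch; assumes `opt` has created slots for `var`, as
  `MomentumOptimizer` does once `minimize()` has been called):

  ```python
  for slot_name in opt.get_slot_names():
    slot = opt.get_slot(var, slot_name)  # `None` if no slot was created.
  ```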

  ### Non-slot variables

  Some optimizer subclasses, such as `AdamOptimizer`, have variables that
  are not associated with the variables to train, just the step itself.

  ### Hyper parameters

  These are arguments passed to the optimizer subclass constructor
  (the `__init__` method), and then passed to `self._set_hyper()`.
  They can be either regular Python values (like 1.0), tensors, or
  callables. If they are callable, the callable will be called during
  `apply_gradients()` to get the value for the hyper parameter.
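
  For example, a callable learning rate is re-evaluated on every
  `apply_gradients()` call (a sketch; `decayed_lr` is an illustrative
  function you would define):

  ```python
  opt = GradientDescentOptimizer(learning_rate=lambda: decayed_lr())
  ```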

  ### State

  Internal methods are passed a `state` argument with the correct
  values to use for the slot and non-slot variables, and the hyper
  parameters.
  """

  # Values for gate_gradients.
  GATE_NONE = 0
  GATE_OP = 1
  GATE_GRAPH = 2

  def __init__(self, use_locking, name):
    """Create a new Optimizer.

    This must be called by the constructors of subclasses.
    Note that Optimizer instances should not bind to a single graph,
    and so shouldn't keep Tensors as member variables. Generally
    you should be able to use the _set_hyper()/state.get_hyper()
    facility instead.

    Args:
      use_locking: Bool. If True, use locks to prevent concurrent updates
        to variables.
      name: A non-empty string.  The name to use for accumulators created
        for the optimizer.

    Raises:
      ValueError: If name is malformed.
      RuntimeError: If _create_slots has been overridden instead of
          _create_vars.
    """
    # Note: We intentionally don't call parent __init__.

    # Optimizer._create_slots was replaced by _create_vars in OptimizerV2.
    if (self.__class__._create_slots.__code__ is not  # pylint: disable=protected-access
        OptimizerV2._create_slots.__code__):
      raise RuntimeError("Override _create_vars instead of _create_slots when "
                         "descending from OptimizerV2 (class %s)" %
                         self.__class__.__name__)
    if not name:
      raise ValueError("Must specify the optimizer name")

    self._use_locking = use_locking
    self._name = name
    # Map from graph_key to state for that graph. We use the graph_key
    # since it works in both eager and graph mode, and gives the outer
    # graph inside functions.
    tower_context = distribute_lib.get_tower_context()
    if tower_context is None:
      # We are in a cross-tower context for a DistributionStrategy, meaning
      # only one Optimizer will be created, not one per tower.
      self._per_graph_state = {}
    else:
      # We use get_tower_context().merge_call() to get a single dict
      # shared across all model replicas when running with a
      # DistributionStrategy.
      self._per_graph_state = tower_context.merge_call(lambda _: {})

    # Hyper parameters, and whether they should be re-evaluated every step.
    self._hyper = {}

  def _set_hyper(self, name, value):
    self._hyper[name] = (_is_dynamic(value), value)

  def minimize(self, loss, global_step=None, var_list=None,
               gate_gradients=GATE_OP, aggregation_method=None,
               colocate_gradients_with_ops=False, name=None,
               grad_loss=None, stop_gradients=None,
               scale_loss_by_num_towers=None):
    """Add operations to minimize `loss` by updating `var_list`.

    This method simply combines calls to `compute_gradients()` and
    `apply_gradients()`. If you want to process the gradients before applying
    them, call `compute_gradients()` and `apply_gradients()` explicitly instead
    of using this function.

    Args:
      loss: A `Tensor` containing the value to minimize.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      var_list: Optional list or tuple of `Variable` objects to update to
        minimize `loss`.  Defaults to the list of variables collected in
        the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      name: Optional name for the returned operation.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      stop_gradients: Optional. A Tensor or list of tensors not to differentiate
        through.
      scale_loss_by_num_towers: Optional boolean. If true, scale the loss
        down by the number of towers. By default, auto-detects whether this
        is needed.

    Returns:
      An Operation that updates the variables in `var_list`.  If `global_step`
      was not `None`, that operation also increments `global_step`.

    Raises:
      ValueError: If some of the variables are not `Variable` objects.

    @compatibility(eager)
    When eager execution is enabled, `loss` should be a Python function that
    takes elements of `var_list` as arguments and computes the value to be
    minimized. If `var_list` is None, `loss` should take no arguments.
    Minimization (and gradient computation) is done with respect to the
    elements of `var_list` if not None, else with respect to any trainable
    variables created during the execution of the `loss` function.
    `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
    `grad_loss` are ignored when eager execution is enabled.
    @end_compatibility
    """
    grads_and_vars = self.compute_gradients(
        loss, var_list=var_list, gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss, stop_gradients=stop_gradients,
        scale_loss_by_num_towers=scale_loss_by_num_towers)

    vars_with_grad = [v for g, v in grads_and_vars if g is not None]
    if not vars_with_grad:
      raise ValueError(
          "No gradients provided for any variable, check your graph for ops"
          " that do not support gradients, between variables %s and loss %s." %
          ([str(v) for _, v in grads_and_vars], loss))

    return self.apply_gradients(grads_and_vars, global_step=global_step,
                                name=name)

  def compute_gradients(self, loss, var_list=None,
                        gate_gradients=GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None, stop_gradients=None,
                        scale_loss_by_num_towers=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This is the first part of `minimize()`.  It returns a list
    of (gradient, variable) pairs where "gradient" is the gradient
    for "variable".  Note that "gradient" can be a `Tensor`, an
    `IndexedSlices`, or `None` if there is no gradient for the
    given variable.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking
        no arguments which returns the value to minimize. When eager execution
        is enabled, it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`.  Defaults to the list of variables collected in the graph
        under the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients.  Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with
        the corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
      stop_gradients: Optional. A Tensor or list of tensors not to differentiate
        through.
      scale_loss_by_num_towers: Optional boolean. If true, scale the loss
        down by the number of towers. By default, auto-detects whether this
        is needed.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.

    Raises:
      TypeError: If `var_list` contains anything other than `Variable` objects.
      ValueError: If some arguments are invalid.
      RuntimeError: If called with eager execution enabled and `loss` is
        not callable.

    @compatibility(eager)
    When eager execution is enabled, `gate_gradients`, `aggregation_method`,
    and `colocate_gradients_with_ops` are ignored.
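
    A sketch of the eager pattern (`model_vars` and `compute_loss` are
    illustrative names you would define):

    ```python
    grads_and_vars = opt.compute_gradients(
        lambda: compute_loss(model_vars), var_list=model_vars)
    ```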
    @end_compatibility
    """
    # TODO(josh11b): Test that we handle weight decay in a reasonable way.
    if callable(loss):
      with backprop.GradientTape() as tape:
        if var_list is not None:
          tape.watch(var_list)
        loss_value = loss()

        # Scale loss for number of towers (callable-loss case). In this case,
        # we have to be careful to call distribute_lib.get_loss_reduction()
        # *after* loss() is evaluated, so we know what loss reduction it uses.
        if scale_loss_by_num_towers is None:
          scale_loss_by_num_towers = (
              distribute_lib.get_loss_reduction() == "mean")
        if scale_loss_by_num_towers:
          num_towers = distribute_lib.get_distribution_strategy().num_towers
          if num_towers > 1:
            loss_value *= 1. / num_towers

      if var_list is None:
        var_list = tape.watched_variables()
      grads = tape.gradient(loss_value, var_list, grad_loss)
      return list(zip(grads, var_list))
    if context.executing_eagerly():
      raise RuntimeError(
          "`loss` passed to Optimizer.compute_gradients should "
          "be a function when eager execution is enabled.")

    # Scale loss for number of towers (non-callable-loss case).
    if scale_loss_by_num_towers is None:
      scale_loss_by_num_towers = (
          distribute_lib.get_loss_reduction() == "mean")
    if scale_loss_by_num_towers:
      num_towers = distribute_lib.get_distribution_strategy().num_towers
      if num_towers > 1:
        loss *= 1. / num_towers

    if gate_gradients not in [optimizer_v1.Optimizer.GATE_NONE,
                              optimizer_v1.Optimizer.GATE_OP,
                              optimizer_v1.Optimizer.GATE_GRAPH]:
      raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
                       "Optimizer.GATE_OP, Optimizer.GATE_GRAPH.  Not %s" %
                       gate_gradients)
    self._assert_valid_dtypes([loss])
    if grad_loss is not None:
      self._assert_valid_dtypes([grad_loss])
    if var_list is None:
      var_list = (
          variables.trainable_variables() +
          ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    else:
      var_list = nest.flatten(var_list)
    # pylint: disable=protected-access
    var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
    # pylint: enable=protected-access
    processors = [_get_processor(v) for v in var_list]
    if not var_list:
      raise ValueError("No variables to optimize.")
    var_refs = [p.target() for p in processors]
    grads = gradients.gradients(
        loss, var_refs, grad_ys=grad_loss,
        gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP),
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        stop_gradients=stop_gradients)
    if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH:
      grads = control_flow_ops.tuple(grads)
    grads_and_vars = list(zip(grads, var_list))
    self._assert_valid_dtypes(
        [v for g, v in grads_and_vars
         if g is not None and v.dtype != dtypes.resource])
    return grads_and_vars

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    applies gradients.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Defaults to the
        name passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients. If `global_step`
      was not None, that operation also increments `global_step`.

    Raises:
      TypeError: If `grads_and_vars` is malformed.
      ValueError: If none of the variables have gradients.
    """
    # This is a default implementation of apply_gradients() that can be shared
    # by most optimizers.  It relies on the subclass implementing the following
    # methods: _create_vars(), _prepare(), _apply_dense(), and _apply_sparse().

    # Filter out variables with gradients of `None`.
    grads_and_vars = tuple(grads_and_vars)  # Make sure repeat iteration works.
    if not grads_and_vars:
      raise ValueError("No variables provided.")
    filtered = tuple((g, v) for (g, v) in grads_and_vars if g is not None)
    if not filtered:
      raise ValueError("No gradients provided for any variable: %s." %
                       ([str(v) for _, v in grads_and_vars],))
    return distribute_lib.get_tower_context().merge_call(
        self._distributed_apply, filtered, global_step=global_step, name=name)

  def _get_or_create_state(self, var_list=None):
    """Either looks up or creates `_OptimizerV2State`.

    If any variables are available, they should be passed via the `var_list`
    argument, and these will be used to determine the graph to create/retrieve
    state for. Otherwise the returned state is for the current default graph.

    Args:
      var_list: A list of variables to extract a graph from.

    Returns:
      An `_OptimizerV2State` object.
    """
    # Determine the graph_key from the current graph.
    eager_execution = context.executing_eagerly()
    if eager_execution or var_list is None:
      graph = ops.get_default_graph()
    else:
      graph = ops._get_graph_from_inputs(var_list)  # pylint: disable=protected-access
    assert graph is not None
    graph_key = graph._graph_key  # pylint: disable=protected-access

    # Get the per graph state by looking up the graph_key.
    if graph_key in self._per_graph_state:
      per_graph_state = self._per_graph_state[graph_key]
    else:
      per_graph_state = _OptimizerV2State(self._name)
      per_graph_state._init_with_static_hyper(self._hyper)  # pylint: disable=protected-access
      self._per_graph_state[graph_key] = per_graph_state
    return per_graph_state

  def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
    """`apply_gradients` for use with a `DistributionStrategy`."""
    reduced_grads = distribution.batch_reduce("sum", grads_and_vars)
    var_list = [v for _, v in grads_and_vars]
    grads_and_vars = zip(reduced_grads, var_list)

    unwrapped_var_list = [x for v in var_list for x in distribution.unwrap(v)]
    eager_execution = context.executing_eagerly()
    if eager_execution:
      # Give a clear error in this case instead of "name not supported
      # for Eager Tensors" when we compute non_slot_devices.
      for v in unwrapped_var_list:
        if isinstance(v, ops.Tensor):
          raise NotImplementedError("Trying to update a Tensor ", v)

    with ops.name_scope(name, self._name) as name:
      per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list)
      # Include the current value of any dynamic hyper parameters in `state`.
      non_slot_devices = distribution.non_slot_devices(var_list)
      state = per_graph_state._copy_with_dynamic_hyper(  # pylint: disable=protected-access
          self._hyper, distribution, non_slot_devices)

    # Create any slot and non-slot variables we need in `state`.
    with ops.init_scope():
      self._create_vars(var_list, state)

    with ops.name_scope(name):  # Re-enter name_scope created above
      # Give the child class a chance to do something before we start
      # applying gradients.
      self._prepare(state)

      def update(v, g):
        """Update variable `v` using gradient `g`."""
        assert v is not None

        # Convert the grad to Tensor or IndexedSlices if necessary, and
        # look up a processor for each variable's type.
        try:
          g = ops.convert_to_tensor_or_indexed_slices(g)
        except TypeError:
          raise TypeError(
              "Gradient must be convertible to a Tensor"
              " or IndexedSlices, or None: %s" % g)
        if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
          raise TypeError(
              "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
        processor = _get_processor(v)

        # We colocate all ops created in _apply_dense or _apply_sparse
        # on the same device as the variable.
        # TODO(apassos): figure out how to get the variable name here.
        scope_name = "" if eager_execution else v.op.name
        # device_policy is set because non-mirrored tensors will be read in
        # `update_op`.
        # TODO(josh11b): Make different state objects for each device to
        # avoid needing to set the device_policy.
        with ops.name_scope("update_" + scope_name), \
            context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
          return processor.update_op(self, g, state)

      # Use the processors to update the variables.
      update_ops = []
      for grad, var in grads_and_vars:
        update_ops.extend(distribution.unwrap(distribution.update(
            var, update, grad)))

      # Give the child class a chance to do something after applying
      # gradients
      def finish():
        # TODO(josh11b): Make different state objects for each device to
        # avoid needing to set the device_policy.
        with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
          return self._finish(state)

      update_ops = control_flow_ops.group(update_ops)
      with ops.control_dependencies([update_ops]):
        finish_updates = distribution.update_non_slot(non_slot_devices, finish)
      if finish_updates is None:
        finish_updates = update_ops

      # Update `global_step` (if any).
      if global_step is None:
        apply_updates = distribution.group(finish_updates, name=name)
      else:
        with ops.control_dependencies(distribution.unwrap(finish_updates)):

          def update_global_step(global_step):
            if isinstance(global_step, resource_variable_ops.ResourceVariable):
              return global_step.assign_add(
                  ops.convert_to_tensor(1, dtype=global_step.dtype),
                  read_value=False)
            else:
              return state_ops.assign_add(global_step, 1)

          apply_updates = distribution.group(
              distribution.update(global_step, update_global_step), name=name)

      # Add the training op to the TRAIN_OP graph collection in graph mode.
      if not eager_execution:
        if isinstance(apply_updates, ops.Tensor):
          apply_updates = apply_updates.op
        train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
        if apply_updates not in train_op:
          train_op.append(apply_updates)

      return apply_updates

  def get_slot(self, var, name):
    """Return a slot named `name` created for `var` by the Optimizer.

    Some `Optimizer` subclasses use additional variables.  For example
    `Momentum` and `Adagrad` use variables to accumulate updates.  This method
    gives access to these `Variable` objects if for some reason you need them.

    Use `get_slot_names()` to get the list of slot names created by the
    `Optimizer`.

    Args:
      var: A variable passed to `minimize()` or `apply_gradients()`.
      name: A string.

    Returns:
      The `Variable` for the slot if it was created, `None` otherwise.
    """
    state = self._get_state_for_var(var)
    return state.get_slot(var, name) if state is not None else None

  def get_slot_names(self):
    """Return a list of the names of slots created by the `Optimizer`.

    See `get_slot()`.

    Returns:
      A list of strings.
    """
    state = self._get_per_graph_state()
    return state.get_slot_names() if state is not None else []

  def variables(self):
    """A list of variables which encode the current state of `Optimizer`.

    Includes slot variables and additional global variables created by the
    optimizer in the current default graph.

    Returns:
      A list of variables.
    """
    state = self._get_per_graph_state()
    return state._variables() if state is not None else []  # pylint: disable=protected-access

  # --------------
  # Methods to be implemented by subclasses if they want to use the
  # inherited implementation of apply_gradients() or compute_gradients().
  # --------------
  def _create_vars(self, var_list, state):
    """Create all slots needed by the variables and any non-slot variables.

    Args:
      var_list: A list of `Variable` objects.
      state: An object with these methods:
        `create_slot(var, val, slot_name, optional_op_name)`,
        `create_slot_with_initializer(`
            `var, initializer, shape, dtype, slot_name, optional_op_name)`,
        `zeros_slot(var, slot_name, optional_op_name)`,
        `create_non_slot(initial_value, name, colocate_with)`,
        `get_hyper(name)`
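
    A minimal sketch of an override (illustrative only; creates one
    zero-initialized slot per variable):

    ```python
    def _create_vars(self, var_list, state):
      for v in var_list:
        state.zeros_slot(v, "accumulator")
    ```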
    """
    # No slots needed by default
    pass

  def _prepare(self, state):
    """Code to execute before applying gradients.

    Note that most uses of _prepare() in Optimizer have been subsumed
    by explicit support for hyper parameters in OptimizerV2.

    Args:
      state: An object with a `get_hyper(name)` method.

    Returns:
      Return value will be ignored.
    """
    pass

  def _apply_dense(self, grad, var, state):
    """Add ops to apply dense gradients to `var`.

    Args:
      grad: A `Tensor`.
      var: A `Variable` object.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _resource_apply_dense(self, grad, handle, state):
    """Add ops to apply dense gradients to the variable `handle`.

    Args:
      grad: a `Tensor` representing the gradient.
      handle: a `Tensor` of dtype `resource` which points to the variable
       to be updated.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _resource_apply_sparse_duplicate_indices(
      self, grad, handle, indices, state):
    """Add ops to apply sparse gradients to `handle`, with repeated indices.

    Optimizers which override this method must deal with repeated indices. See
    the docstring of `_apply_sparse_duplicate_indices` for details. By default
    the correct behavior, to sum non-unique indices and their associated
    gradients, is enforced by first pre-processing `grad` and `indices` and
    passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
    with duplicate indices may instead override this method to avoid the
    overhead of summing.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
       to be updated.
      indices: a `Tensor` of integral type representing the indices for
       which the gradient is nonzero. Indices may be repeated.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    # pylint: disable=protected-access
    summed_grad, unique_indices = optimizer_v1._deduplicate_indexed_slices(
        values=grad, indices=indices)
    # pylint: enable=protected-access
    return self._resource_apply_sparse(
        summed_grad, handle, unique_indices, state)

  def _resource_apply_sparse(self, grad, handle, indices, state):
    """Add ops to apply sparse gradients to the variable `handle`.

    Similar to `_apply_sparse`, the `indices` argument to this method has been
    de-duplicated. Optimizers which deal correctly with non-unique indices may
    instead override `_resource_apply_sparse_duplicate_indices` to avoid this
    overhead.

    Args:
      grad: a `Tensor` representing the gradient for the affected indices.
      handle: a `Tensor` of dtype `resource` which points to the variable
       to be updated.
      indices: a `Tensor` of integral type representing the indices for
       which the gradient is nonzero. Indices are unique.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation` which updates the value of the variable.
    """
    raise NotImplementedError()

  def _apply_sparse_duplicate_indices(self, grad, var, state):
    """Add ops to apply sparse gradients to `var`, with repeated sparse indices.

    Optimizers which override this method must deal with IndexedSlices objects
    such as the following:

      IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])

    The correct interpretation is:

      IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])

    Many optimizers deal incorrectly with repeated indices when updating based
    on sparse gradients (e.g. summing squares rather than squaring the sum, or
    applying momentum terms multiple times). Adding first is always the correct
    behavior, so this is enforced here by reconstructing the IndexedSlices to
    have only unique indices, then calling _apply_sparse.

    Optimizers which deal correctly with repeated indices may instead override
    this method to avoid the overhead of summing indices.

    Args:
      grad: `IndexedSlices`.
      var: A `Variable` object.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation`.
    """
    # pylint: disable=protected-access
    summed_values, unique_indices = optimizer_v1._deduplicate_indexed_slices(
        values=grad.values, indices=grad.indices)
    # pylint: enable=protected-access
    gradient_no_duplicate_indices = ops.IndexedSlices(
        indices=unique_indices,
        values=summed_values,
        dense_shape=grad.dense_shape)
    return self._apply_sparse(gradient_no_duplicate_indices, var, state)

  def _apply_sparse(self, grad, var, state):
    """Add ops to apply sparse gradients to `var`.

    The IndexedSlices object passed to `grad` in this function is by default
    pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
    indices (see its docstring for details). Optimizers which can tolerate or
    have correct special cases for duplicate sparse indices may override
    `_apply_sparse_duplicate_indices` instead of this function, avoiding that
    overhead.

    Args:
      grad: `IndexedSlices`, with no repeated indices.
      var: A `Variable` object.
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      An `Operation`.
    """
    raise NotImplementedError()

  def _finish(self, state):
    """Do what is needed to finish the update.

    This is called inside a scope colocated with any non-slot variables.

    Args:
      state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
        and `get_hyper(name)` methods.

    Returns:
      The operation to apply updates, or None if no updates.
    """
    return None

  # --------------
  # Utility methods for subclasses.
  # --------------
  def _get_per_graph_state(self):
    # pylint: disable=protected-access
    return self._per_graph_state.get(ops.get_default_graph()._graph_key, None)

  def _get_state_for_var(self, var):
    # pylint: disable=protected-access
    return self._per_graph_state.get(var._graph_key, None)

  # --------------
  # Overridden methods from Checkpointable.
  # --------------

  def _track_checkpointable(self, *args, **kwargs):
    """Optimizers may not track dependencies. Raises an error."""
    raise NotImplementedError(
        "Optimizers may not have dependencies. File a feature request if this "
        "limitation bothers you.")

  @property
  def _checkpoint_dependencies(self):
    """From Checkpointable. Gather graph-specific non-slot variables to save."""
    current_graph_non_slot_variables = []
    state = self._get_per_graph_state()
    if state is not None:
      for name, variable_object in sorted(
          state._non_slot_dict.items(),  # pylint: disable=protected-access
          # Avoid comparing variables
          key=lambda item: item[0]):
        current_graph_non_slot_variables.append(
            checkpointable.CheckpointableReference(
                name=name, ref=variable_object))
    # Note: ignores super(); Optimizers may not have any dependencies outside of
    # state objects.
    return current_graph_non_slot_variables

  def _lookup_dependency(self, name):
    """From Checkpointable. Find a non-slot variable in the current graph."""
    state = self._get_per_graph_state()
    if state is None:
      return None
    else:
      return state.get_non_slot(name)

  @property
  def _deferred_dependencies(self):
    """Lets Checkpointable know where non-slot variables are created.

    If necessary, creates a new state object for the current default graph.
    Checkpointable will then add entries to that state's deferred dependency
    dictionary. The state object will check that dictionary when creating
    non-slot variables, restoring their value if an entry is found.

    Returns:
      A dictionary which holds deferred dependencies for the current default
      graph.
    """
    state = self._get_or_create_state()
    return state._deferred_dependencies  # pylint: disable=protected-access

  def _create_or_restore_slot_variable(
      self, slot_variable_position, slot_name, variable):
    """Checkpointable: Restore a slot variable's value, possibly creating it.

    Called when a variable which has an associated slot variable is created or
    restored.

    Args:
      slot_variable_position: A `checkpointable._CheckpointPosition` object
        indicating the slot variable `Checkpointable` object to be restored.
      slot_name: The name of this `Optimizer`'s slot to restore into.
      variable: The variable object this slot is being created for.
    """
    state = self._get_or_create_state(var_list=[variable])
    state._create_or_restore_slot_variable(  # pylint: disable=protected-access
        slot_variable_position=slot_variable_position,
        slot_name=slot_name,
        variable=variable,
        optional_op_name=self._name)

  # --------------
  # Unsupported parent methods
  # --------------
  def _slot_dict(self, slot_name):
    raise NotImplementedError(
        "_slot_dict() method unsupported in OptimizerV2")

  def _get_or_make_slot(self, var, val, slot_name, op_name):
    raise NotImplementedError(
        "_get_or_make_slot() method unsupported in OptimizerV2")

  def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
                                         slot_name, op_name):
    raise NotImplementedError(
        "_get_or_make_slot_with_initializer() method unsupported in "
        "OptimizerV2")

  def _create_non_slot_variable(self, initial_value, name, colocate_with):
    raise NotImplementedError(
        "_create_non_slot_variable() method unsupported in OptimizerV2")

  def _get_non_slot_variable(self, name, graph=None):
    raise NotImplementedError(
        "_get_non_slot_variable() method unsupported in OptimizerV2")

  def _non_slot_variables(self):
    raise NotImplementedError(
        "_non_slot_variables() method unsupported in OptimizerV2")