diff options
Diffstat (limited to 'tensorflow/python/training/optimizer.py')
-rw-r--r-- | tensorflow/python/training/optimizer.py | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index 9e029389f2..e862cb87b1 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -34,6 +34,8 @@ from tensorflow.python.training import slot_creator def _var_key(var): + if var.op.type == "ResourceGather": + var = var.op.inputs[0] return (var.op.graph, var.op.name) @@ -530,6 +532,33 @@ class Optimizer(object): raise NotImplementedError() def _resource_apply_dense(self, grad, handle): + """Add ops to apply dense gradients to the variable `handle`. + + Args: + grad: a `Tensor` representing the gradient. + handle: a `Tensor` of dtype `resource` which points to the variable + to be updated. + + Returns: + An `Operation` which updates the value of the variable. + """ + raise NotImplementedError() + + def _resource_apply_sparse(self, grad, handle, indices): + """Add ops to apply sparse gradients to the variable `handle`. + + + Args: + grad: a `Tensor` representing the gradient for the affected indices. + handle: a `Tensor` of dtype `resource` which points to the variable + to be updated. + indices: a `Tensor` of integral type representing the indices for + which the gradient is nonzero. + + Returns: + An `Operation` which updates the value of the variable. + + """ raise NotImplementedError() def _apply_sparse_duplicate_indices(self, grad, var): |