path: root/tensorflow/python/training/optimizer.py
Diffstat (limited to 'tensorflow/python/training/optimizer.py')
-rw-r--r--  tensorflow/python/training/optimizer.py | 29
1 file changed, 29 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 9e029389f2..e862cb87b1 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -34,6 +34,8 @@ from tensorflow.python.training import slot_creator
 
 
 def _var_key(var):
+  if var.op.type == "ResourceGather":
+    var = var.op.inputs[0]
   return (var.op.graph, var.op.name)
 
 
@@ -530,6 +532,31 @@ class Optimizer(object):
     raise NotImplementedError()
 
   def _resource_apply_dense(self, grad, handle):
+    """Add ops to apply dense gradients to the variable `handle`.
+
+    Args:
+      grad: a `Tensor` representing the gradient.
+      handle: a `Tensor` of dtype `resource` which points to the variable
+        to be updated.
+
+    Returns:
+      An `Operation` which updates the value of the variable.
+    """
+    raise NotImplementedError()
+
+  def _resource_apply_sparse(self, grad, handle, indices):
+    """Add ops to apply sparse gradients to the variable `handle`.
+
+    Args:
+      grad: a `Tensor` representing the gradient for the affected indices.
+      handle: a `Tensor` of dtype `resource` which points to the variable
+        to be updated.
+      indices: a `Tensor` of integral type representing the indices for
+        which the gradient is nonzero.
+
+    Returns:
+      An `Operation` which updates the value of the variable.
+    """
     raise NotImplementedError()
 
   def _apply_sparse_duplicate_indices(self, grad, var):
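
For context, a concrete optimizer subclass is expected to override these two hooks. The sketch below is illustrative only and not part of this patch; it assumes the resource-variable op wrappers resource_variable_ops.assign_sub_variable_op and resource_variable_ops.resource_scatter_add for the in-place updates, and shows a minimal SGD-style implementation of both hooks.

# Illustrative sketch only -- not part of this diff. Shows how a subclass
# might implement the two new hooks. The resource_variable_ops helpers used
# here (assign_sub_variable_op, resource_scatter_add) are assumed wrappers
# around the resource-variable update kernels.
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import optimizer


class SketchResourceSGD(optimizer.Optimizer):
  """Toy gradient-descent optimizer that only handles resource variables."""

  def __init__(self, learning_rate, use_locking=False, name="SketchResourceSGD"):
    super(SketchResourceSGD, self).__init__(use_locking, name)
    self._learning_rate = learning_rate

  def _resource_apply_dense(self, grad, handle):
    # `handle` is a `resource`-dtype tensor; subtract lr * grad in place.
    return resource_variable_ops.assign_sub_variable_op(
        handle, self._learning_rate * grad)

  def _resource_apply_sparse(self, grad, handle, indices):
    # Update only the rows named by `indices` by scatter-adding -lr * grad.
    return resource_variable_ops.resource_scatter_add(
        handle, indices, -self._learning_rate * grad)

The dense hook rewrites the whole variable through its resource handle, while the sparse hook touches only the rows named by `indices`; that is presumably why `_var_key` above maps a `ResourceGather` read back to its source variable, so slot lookups key on the variable itself rather than on the gather op.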