aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/training_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/training_ops.py')
-rw-r--r--tensorflow/python/training/training_ops.py17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py
index 8619752338..1a96f77c1c 100644
--- a/tensorflow/python/training/training_ops.py
+++ b/tensorflow/python/training/training_ops.py
@@ -170,6 +170,23 @@ def _SparseApplyProximalGradientDescentShape(op):
return [var_shape]
+@ops.RegisterShape("SparseApplyRMSProp")
+def _SparseApplyRMSPropShape(op):
+ """Shape function for the SparseApplyRMSProp op."""
+ var_shape = op.inputs[0].get_shape()
+ ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
+ _AssertInputIsScalar(op, 3) # lr
+ _AssertInputIsScalar(op, 4) # rho
+ _AssertInputIsScalar(op, 5) # momentum
+ _AssertInputIsScalar(op, 6) # epsilon
+ grad_shape = op.inputs[7].get_shape().merge_with(
+ tensor_shape.TensorShape([None]).concatenate(mom_shape[1:]))
+ unused_indices_shape = op.inputs[8].get_shape().merge_with(
+ tensor_shape.vector(grad_shape[0]))
+ return [mom_shape]
+
+
@ops.RegisterShape("SparseApplyAdadelta")
def _SparseApplyAdadeltaShape(op):
"""Shape function for the SparseApplyAdadelta op."""