diff options
Diffstat (limited to 'tensorflow/python/training/training_ops.py')
-rw-r--r-- | tensorflow/python/training/training_ops.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py index 8619752338..1a96f77c1c 100644 --- a/tensorflow/python/training/training_ops.py +++ b/tensorflow/python/training/training_ops.py @@ -170,6 +170,23 @@ def _SparseApplyProximalGradientDescentShape(op): return [var_shape] +@ops.RegisterShape("SparseApplyRMSProp") +def _SparseApplyRMSPropShape(op): + """Shape function for the SparseApplyRMSProp op.""" + var_shape = op.inputs[0].get_shape() + ms_shape = op.inputs[1].get_shape().merge_with(var_shape) + mom_shape = op.inputs[2].get_shape().merge_with(ms_shape) + _AssertInputIsScalar(op, 3) # lr + _AssertInputIsScalar(op, 4) # rho + _AssertInputIsScalar(op, 5) # momentum + _AssertInputIsScalar(op, 6) # epsilon + grad_shape = op.inputs[7].get_shape().merge_with( + tensor_shape.TensorShape([None]).concatenate(mom_shape[1:])) + unused_indices_shape = op.inputs[8].get_shape().merge_with( + tensor_shape.vector(grad_shape[0])) + return [mom_shape] + + @ops.RegisterShape("SparseApplyAdadelta") def _SparseApplyAdadeltaShape(op): """Shape function for the SparseApplyAdadelta op.""" |