Diffstat (limited to 'tensorflow/python/training/training_ops.py')
-rw-r--r-- | tensorflow/python/training/training_ops.py | 115
1 file changed, 115 insertions, 0 deletions
diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py
new file mode 100644
index 0000000000..410b23e04d
--- /dev/null
+++ b/tensorflow/python/training/training_ops.py
@@ -0,0 +1,115 @@
+"""Python wrappers for training ops."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.training import gen_training_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.training.gen_training_ops import *
+# pylint: enable=wildcard-import
+
+
+# Shape functions for fused training ops
+# --------------------------------------
+#
+# The fused training ops all have the same basic structure: they take
+# one or more variables with the same shape, and emit a reference to
+# the original variable (which has the same shape as the first
+# input). In addition, they take one or more scalar tensors containing
+# hyperparameters.
+#
+# The sparse ops take the gradients as a Python IndexedSlices, which
+# means that the indices are a vector of length N, and the gradient
+# values are a tensor whose size is the same as the original variable,
+# except for the 0th dimension, which has size N.
+
+
+def _AssertInputIsScalar(op, index):
+  """Raises ValueError if `op.inputs[index]` is not scalar."""
+  op.inputs[index].get_shape().assert_is_compatible_with(tensor_shape.scalar())
+
+
+@ops.RegisterShape("ApplyAdagrad")
+def _ApplyAdagradShape(op):
+  """Shape function for the ApplyAdagrad op."""
+  var_shape = op.inputs[0].get_shape()
+  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
+  _AssertInputIsScalar(op, 2)  # lr
+  grad_shape = op.inputs[3].get_shape().merge_with(accum_shape)
+  return [grad_shape]
+
+
+@ops.RegisterShape("ApplyAdam")
+def _ApplyAdamShape(op):
+  """Shape function for the ApplyAdam op."""
+  var_shape = op.inputs[0].get_shape()
+  m_shape = op.inputs[1].get_shape().merge_with(var_shape)
+  v_shape = op.inputs[2].get_shape().merge_with(m_shape)
+  _AssertInputIsScalar(op, 3)  # beta1_power
+  _AssertInputIsScalar(op, 4)  # beta2_power
+  _AssertInputIsScalar(op, 5)  # lr
+  _AssertInputIsScalar(op, 6)  # beta1
+  _AssertInputIsScalar(op, 7)  # beta2
+  _AssertInputIsScalar(op, 8)  # epsilon
+  grad_shape = op.inputs[9].get_shape().merge_with(v_shape)
+  return [grad_shape]
+
+
+@ops.RegisterShape("ApplyMomentum")
+def _ApplyMomentumShape(op):
+  """Shape function for the ApplyMomentum op."""
+  var_shape = op.inputs[0].get_shape()
+  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
+  _AssertInputIsScalar(op, 2)  # lr
+  grad_shape = op.inputs[3].get_shape().merge_with(accum_shape)
+  _AssertInputIsScalar(op, 4)  # momentum
+  return [grad_shape]
+
+
+@ops.RegisterShape("ApplyRMSProp")
+def _ApplyRMSPropShape(op):
+  """Shape function for the ApplyRMSProp op."""
+  var_shape = op.inputs[0].get_shape()
+  ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
+  mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
+  _AssertInputIsScalar(op, 3)  # lr
+  _AssertInputIsScalar(op, 4)  # rho
+  _AssertInputIsScalar(op, 5)  # momentum
+  _AssertInputIsScalar(op, 6)  # epsilon
+  grad_shape = op.inputs[7].get_shape().merge_with(mom_shape)
+  return [grad_shape]
+
+
+@ops.RegisterShape("ApplyGradientDescent")
+def _ApplyGradientDescentShape(op):
+  """Shape function for the ApplyGradientDescent op."""
+  var_shape = op.inputs[0].get_shape()
+  _AssertInputIsScalar(op, 1)  # alpha
+  delta_shape = op.inputs[2].get_shape().merge_with(var_shape)
+  return [delta_shape]
+
+
+@ops.RegisterShape("SparseApplyAdagrad")
+def _SparseApplyAdagradShape(op):
+  """Shape function for the SparseApplyAdagrad op."""
+  var_shape = op.inputs[0].get_shape()
+  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
+  _AssertInputIsScalar(op, 2)  # lr
+  grad_shape = op.inputs[3].get_shape().merge_with(
+      tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
+  unused_indices_shape = op.inputs[4].get_shape().merge_with(
+      tensor_shape.vector(grad_shape[0]))
+  return [accum_shape]
+
+
+@ops.RegisterShape("SparseApplyMomentum")
+def _SparseApplyMomentumShape(op):
+  """Shape function for the SparseApplyMomentum op."""
+  var_shape = op.inputs[0].get_shape()
+  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
+  _AssertInputIsScalar(op, 2)  # lr
+  grad_shape = op.inputs[3].get_shape().merge_with(
+      tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
+  unused_indices_shape = op.inputs[4].get_shape().merge_with(
+      tensor_shape.vector(grad_shape[0]))
+  _AssertInputIsScalar(op, 5)  # momentum
+  return [accum_shape]
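
The dense shape functions in this patch all reduce to the same pattern: successively merge_with() every variable-shaped input and assert that each hyperparameter is a scalar. A minimal sketch (not part of the patch; the shapes are made up for illustration) of the TensorShape merging semantics they rely on:

from tensorflow.python.framework import tensor_shape

var_shape = tensor_shape.TensorShape([10, 3])

# merge_with fills in unknown dimensions from the other shape.
merged = tensor_shape.TensorShape([10, None]).merge_with(var_shape)
print(merged.as_list())  # [10, 3]

# It raises ValueError on incompatible shapes, which is how these shape
# functions reject mis-shaped optimizer inputs at graph-construction time
# rather than at run time.
try:
  tensor_shape.TensorShape([10, 3]).merge_with(tensor_shape.TensorShape([4]))
except ValueError as e:
  print(e)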
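For the sparse variants, the expression tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]) encodes the IndexedSlices contract described in the header comment: the gradient values match the variable shape except along dimension 0, whose size N is pinned down only by the indices vector. A sketch with hypothetical shapes (a 1000x64 embedding matrix with 5 rows updated), using the same helpers the patch itself uses:

from tensorflow.python.framework import tensor_shape

var_shape = tensor_shape.TensorShape([1000, 64])  # hypothetical variable

# Expected gradient-values shape: unknown N rows, trailing dims from the var.
expected_grad = tensor_shape.TensorShape([None]).concatenate(var_shape[1:])
print(expected_grad.as_list())  # [None, 64]

# Merging with a concrete gradient pins N down; the indices must then be a
# length-N vector, mirroring the tensor_shape.vector(...) check above.
grad_shape = expected_grad.merge_with([5, 64])
indices_shape = tensor_shape.vector(grad_shape[0])
print(grad_shape.as_list(), indices_shape.as_list())  # [5, 64] [5]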