Diffstat (limited to 'tensorflow/python/training/training_ops.py')
-rw-r--r--  tensorflow/python/training/training_ops.py | 115
1 file changed, 115 insertions, 0 deletions
diff --git a/tensorflow/python/training/training_ops.py b/tensorflow/python/training/training_ops.py
new file mode 100644
index 0000000000..410b23e04d
--- /dev/null
+++ b/tensorflow/python/training/training_ops.py
@@ -0,0 +1,115 @@
+"""Python wrappers for training ops."""
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.training import gen_training_ops
+# pylint: disable=wildcard-import
+from tensorflow.python.training.gen_training_ops import *
+# pylint: enable=wildcard-import
+
+
+# Shape functions for fused training ops
+# --------------------------------------
+#
+# The fused training ops all have the same basic structure: they take
+# one or more variables with the same shape, and emit a reference to
+# the original variable (which has the same shape as the first
+# input). In addition, they take one or more scalar tensors containing
+# hyperparameters.
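+#
+# For example, with a [10, 3] variable, ApplyAdagrad below expects
+# var: [10, 3], accum: [10, 3], lr: [] (scalar), grad: [10, 3], and
+# its output has shape [10, 3].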
+#
+# The sparse ops take the gradient as a Python IndexedSlices: the
+# indices are a vector of length N, and the gradient values are a
+# tensor whose shape matches the original variable, except for the
+# 0th dimension, which has size N.
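+#
+# For example, a sparse update of N = 4 rows of a [10, 3] variable has
+# indices of shape [4] and gradient values of shape [4, 3].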
+
+
+def _AssertInputIsScalar(op, index):
+ """Raises ValueError if `op.inputs[index]` is not scalar."""
+ op.inputs[index].get_shape().assert_is_compatible_with(tensor_shape.scalar())
+
+
+@ops.RegisterShape("ApplyAdagrad")
+def _ApplyAdagradShape(op):
+ """Shape function for the ApplyAdagrad op."""
+ var_shape = op.inputs[0].get_shape()
+ accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ _AssertInputIsScalar(op, 2) # lr
+ grad_shape = op.inputs[3].get_shape().merge_with(accum_shape)
+ return [grad_shape]
+
+
+@ops.RegisterShape("ApplyAdam")
+def _ApplyAdamShape(op):
+ """Shape function for the ApplyAdam op."""
+ var_shape = op.inputs[0].get_shape()
+ m_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ v_shape = op.inputs[2].get_shape().merge_with(m_shape)
+ _AssertInputIsScalar(op, 3) # beta1_power
+ _AssertInputIsScalar(op, 4) # beta2_power
+ _AssertInputIsScalar(op, 5) # lr
+ _AssertInputIsScalar(op, 6) # beta1
+ _AssertInputIsScalar(op, 7) # beta2
+ _AssertInputIsScalar(op, 8) # epsilon
+ grad_shape = op.inputs[9].get_shape().merge_with(v_shape)
+ return [grad_shape]
+
+
+@ops.RegisterShape("ApplyMomentum")
+def _ApplyMomentumShape(op):
+ """Shape function for the ApplyMomentum op."""
+ var_shape = op.inputs[0].get_shape()
+ accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ _AssertInputIsScalar(op, 2) # lr
+ grad_shape = op.inputs[3].get_shape().merge_with(accum_shape)
+ _AssertInputIsScalar(op, 4) # momentum
+ return [grad_shape]
+
+
+@ops.RegisterShape("ApplyRMSProp")
+def _ApplyRMSPropShape(op):
+ """Shape function for the ApplyRMSProp op."""
+ var_shape = op.inputs[0].get_shape()
+ ms_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ mom_shape = op.inputs[2].get_shape().merge_with(ms_shape)
+ _AssertInputIsScalar(op, 3) # lr
+ _AssertInputIsScalar(op, 4) # rho
+ _AssertInputIsScalar(op, 5) # momentum
+ _AssertInputIsScalar(op, 6) # epsilon
+ grad_shape = op.inputs[7].get_shape().merge_with(mom_shape)
+ return [grad_shape]
+
+
+@ops.RegisterShape("ApplyGradientDescent")
+def _ApplyGradientDescentShape(op):
+ """Shape function for the ApplyGradientDescent op."""
+ var_shape = op.inputs[0].get_shape()
+ _AssertInputIsScalar(op, 1) # alpha
+ delta_shape = op.inputs[2].get_shape().merge_with(var_shape)
+ return [delta_shape]
+
+
+@ops.RegisterShape("SparseApplyAdagrad")
+def _SparseApplyAdagradShape(op):
+ """Shape function for the SparseApplyAdagrad op."""
+ var_shape = op.inputs[0].get_shape()
+ accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ _AssertInputIsScalar(op, 2) # lr
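+  # The gradient must match the accumulator in every dimension except
+  # the 0th, whose size is the statically unknown number of updates N.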
+ grad_shape = op.inputs[3].get_shape().merge_with(
+ tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
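+  # The indices form a vector with one entry per updated row, so its
+  # length must match the gradient's 0th dimension.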
+ unused_indices_shape = op.inputs[4].get_shape().merge_with(
+ tensor_shape.vector(grad_shape[0]))
+ return [accum_shape]
+
+
+@ops.RegisterShape("SparseApplyMomentum")
+def _SparseApplyMomentumShape(op):
+ """Shape function for the SparseApplyMomentum op."""
+ var_shape = op.inputs[0].get_shape()
+ accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
+ _AssertInputIsScalar(op, 2) # lr
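+  # As in SparseApplyAdagrad above, the gradient's 0th dimension is the
+  # unknown number of updated rows, and the indices must form a vector
+  # of matching length.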
+ grad_shape = op.inputs[3].get_shape().merge_with(
+ tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
+ unused_indices_shape = op.inputs[4].get_shape().merge_with(
+ tensor_shape.vector(grad_shape[0]))
+ _AssertInputIsScalar(op, 5) # momentum
+ return [accum_shape]
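
A minimal sketch of how one of these registered shape functions is
exercised, assuming the graph-construction API of this era (tf.Variable,
tf.train.GradientDescentOptimizer); the shape function runs when the op
is added to the graph, so a shape mismatch fails at graph-construction
time rather than at run time:

  import tensorflow as tf

  var = tf.Variable(tf.zeros([10, 3]))
  loss = tf.reduce_sum(var * var)
  opt = tf.train.GradientDescentOptimizer(learning_rate=0.1)
  # minimize() emits an ApplyGradientDescent op, whose shape function
  # checks that the learning rate is a scalar and that the gradient
  # (delta) matches the variable's [10, 3] shape.
  train_op = opt.minimize(loss)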