aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/training_ops.py
blob: 410b23e04d94b5958f46892dc26c83e1c3272017 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Python wrappers for training ops."""

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.training import gen_training_ops
# pylint: disable=wildcard-import
from tensorflow.python.training.gen_training_ops import *
# pylint: enable=wildcard-import


# Shape functions for fused training ops
# --------------------------------------
#
# The fused training ops all have the same basic structure: they take
# one or more variables with the same shape, and emit a reference to
# the original variable (which has the same shape as the first
# input). In addition, they take one or more scalar tensors containing
# hyperparameters.
#
# The sparse ops take the gradients as a Python IndexedSlices, which
# means that the indices are a vector of length N, and the gradient
# values are a tensor whose shape is the same as the original variable,
# except for the 0th dimension, which has size N.


def _AssertInputIsScalar(op, index):
  """Checks that the `index`-th input of `op` is compatible with a scalar.

  Args:
    op: The operation whose input is inspected.
    index: Position of the input within `op.inputs`.

  Raises:
    ValueError: If `op.inputs[index]` is not scalar.
  """
  input_shape = op.inputs[index].get_shape()
  input_shape.assert_is_compatible_with(tensor_shape.scalar())


@ops.RegisterShape("ApplyAdagrad")
def _ApplyAdagradShape(op):
  """Shape function for the ApplyAdagrad op.

  Merges the shapes of the var, accum and grad inputs (inputs 0, 1 and 3)
  and checks that lr (input 2) is a scalar.

  Args:
    op: The ApplyAdagrad operation.

  Returns:
    A single-element list containing the merged shape.
  """
  merged = op.inputs[0].get_shape()  # var
  merged = op.inputs[1].get_shape().merge_with(merged)  # accum
  _AssertInputIsScalar(op, 2)  # lr
  merged = op.inputs[3].get_shape().merge_with(merged)  # grad
  return [merged]


@ops.RegisterShape("ApplyAdam")
def _ApplyAdamShape(op):
  """Shape function for the ApplyAdam op.

  Merges the shapes of the var, m, v and grad inputs (inputs 0, 1, 2 and 9)
  and checks that the six hyperparameter inputs (3 through 8) are scalars.

  Args:
    op: The ApplyAdam operation.

  Returns:
    A single-element list containing the merged shape.
  """
  merged = op.inputs[0].get_shape()  # var
  for slot_index in (1, 2):  # m, v
    merged = op.inputs[slot_index].get_shape().merge_with(merged)
  # beta1_power, beta2_power, lr, beta1, beta2, epsilon
  for scalar_index in range(3, 9):
    _AssertInputIsScalar(op, scalar_index)
  merged = op.inputs[9].get_shape().merge_with(merged)  # grad
  return [merged]


@ops.RegisterShape("ApplyMomentum")
def _ApplyMomentumShape(op):
  """Shape function for the ApplyMomentum op.

  Merges the shapes of the var, accum and grad inputs (inputs 0, 1 and 3)
  and checks that lr (input 2) and momentum (input 4) are scalars.

  Args:
    op: The ApplyMomentum operation.

  Returns:
    A single-element list containing the merged shape.
  """
  merged = op.inputs[0].get_shape()  # var
  merged = op.inputs[1].get_shape().merge_with(merged)  # accum
  _AssertInputIsScalar(op, 2)  # lr
  merged = op.inputs[3].get_shape().merge_with(merged)  # grad
  _AssertInputIsScalar(op, 4)  # momentum
  return [merged]


@ops.RegisterShape("ApplyRMSProp")
def _ApplyRMSPropShape(op):
  """Shape function for the ApplyRMSProp op.

  Merges the shapes of the var, ms, mom and grad inputs (inputs 0, 1, 2
  and 7) and checks that the four hyperparameter inputs (3 through 6) are
  scalars.

  Args:
    op: The ApplyRMSProp operation.

  Returns:
    A single-element list containing the merged shape.
  """
  merged = op.inputs[0].get_shape()  # var
  for slot_index in (1, 2):  # ms, mom
    merged = op.inputs[slot_index].get_shape().merge_with(merged)
  # lr, rho, momentum, epsilon
  for scalar_index in range(3, 7):
    _AssertInputIsScalar(op, scalar_index)
  merged = op.inputs[7].get_shape().merge_with(merged)  # grad
  return [merged]


@ops.RegisterShape("ApplyGradientDescent")
def _ApplyGradientDescentShape(op):
  """Shape function for the ApplyGradientDescent op.

  Merges the shapes of the var and delta inputs (inputs 0 and 2) and checks
  that alpha (input 1) is a scalar.

  Args:
    op: The ApplyGradientDescent operation.

  Returns:
    A single-element list containing the merged shape.
  """
  merged = op.inputs[0].get_shape()  # var
  _AssertInputIsScalar(op, 1)  # alpha
  merged = op.inputs[2].get_shape().merge_with(merged)  # delta
  return [merged]


@ops.RegisterShape("SparseApplyAdagrad")
def _SparseApplyAdagradShape(op):
  """Shape function for the SparseApplyAdagrad op.

  Merges the shapes of var and accum (inputs 0 and 1) and checks that lr
  (input 2) is a scalar. The grad input (3) is merged with a shape whose
  0th dimension is unknown and whose remaining dimensions are those of the
  merged var/accum shape; the indices input (4) is merged with a vector
  whose length is grad's 0th dimension.

  Args:
    op: The SparseApplyAdagrad operation.

  Returns:
    A single-element list containing the merged var/accum shape.
  """
  dense_shape = op.inputs[0].get_shape()  # var
  dense_shape = op.inputs[1].get_shape().merge_with(dense_shape)  # accum
  _AssertInputIsScalar(op, 2)  # lr

  # grad: [N] followed by every dimension of the variable but the 0th.
  grad_input_shape = op.inputs[3].get_shape()
  grad_template = tensor_shape.TensorShape([None]).concatenate(dense_shape[1:])
  grad_shape = grad_input_shape.merge_with(grad_template)

  # indices: a vector of length N. The merge result is not needed; the call
  # is made only for its compatibility check.
  indices_input_shape = op.inputs[4].get_shape()
  indices_input_shape.merge_with(tensor_shape.vector(grad_shape[0]))
  return [dense_shape]


@ops.RegisterShape("SparseApplyMomentum")
def _SparseApplyMomentumShape(op):
  """Shape function for the SparseApplyMomentum op.

  Merges the shapes of var and accum (inputs 0 and 1) and checks that lr
  (input 2) and momentum (input 5) are scalars. The grad input (3) is merged
  with a shape whose 0th dimension is unknown and whose remaining dimensions
  are those of the merged var/accum shape; the indices input (4) is merged
  with a vector whose length is grad's 0th dimension.

  Args:
    op: The SparseApplyMomentum operation.

  Returns:
    A single-element list containing the merged var/accum shape.
  """
  dense_shape = op.inputs[0].get_shape()  # var
  dense_shape = op.inputs[1].get_shape().merge_with(dense_shape)  # accum
  _AssertInputIsScalar(op, 2)  # lr

  # grad: [N] followed by every dimension of the variable but the 0th.
  grad_input_shape = op.inputs[3].get_shape()
  grad_template = tensor_shape.TensorShape([None]).concatenate(dense_shape[1:])
  grad_shape = grad_input_shape.merge_with(grad_template)

  # indices: a vector of length N. The merge result is not needed; the call
  # is made only for its compatibility check.
  indices_input_shape = op.inputs[4].get_shape()
  indices_input_shape.merge_with(tensor_shape.vector(grad_shape[0]))

  _AssertInputIsScalar(op, 5)  # momentum
  return [dense_shape]