1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
|
"""Python wrappers for training ops."""
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.training import gen_training_ops
# pylint: disable=wildcard-import
from tensorflow.python.training.gen_training_ops import *
# pylint: enable=wildcard-import
# Shape functions for fused training ops
# --------------------------------------
#
# The fused training ops all have the same basic structure: they take
# one or more variables with the same shape, and emit a reference to
# the original variable (which has the same shape as the first
# input). In addition, they take one or more scalar tensors containing
# hyperparameters.
#
# The sparse ops take the gradient in the form of a Python IndexedSlices,
# which means that the indices are a vector of length N, and the gradient
# values are a tensor whose shape is the same as the original variable's,
# except for the 0th dimension, which has size N.
def _AssertInputIsScalar(op, index):
  """Checks that input number `index` of `op` is compatible with a scalar.

  Args:
    op: The operation whose inputs are inspected.
    index: Position of the input in `op.inputs` to check.

  Raises:
    ValueError: If the static shape of `op.inputs[index]` is not compatible
      with a scalar (rank-0) shape.
  """
  input_shape = op.inputs[index].get_shape()
  input_shape.assert_is_compatible_with(tensor_shape.scalar())
@ops.RegisterShape("ApplyAdagrad")
def _ApplyAdagradShape(op):
  """Infers the output shape of the ApplyAdagrad op.

  Inputs are (var, accum, lr, grad): `var`, `accum` and `grad` must have
  compatible shapes, and `lr` must be a scalar.

  Args:
    op: The ApplyAdagrad operation.

  Returns:
    A one-element list holding the merged shape of var, accum and grad.
  """
  merged = op.inputs[0].get_shape()  # var
  merged = op.inputs[1].get_shape().merge_with(merged)  # accum
  _AssertInputIsScalar(op, 2)  # lr
  merged = op.inputs[3].get_shape().merge_with(merged)  # grad
  return [merged]
@ops.RegisterShape("ApplyAdam")
def _ApplyAdamShape(op):
  """Infers the output shape of the ApplyAdam op.

  Inputs are (var, m, v, beta1_power, beta2_power, lr, beta1, beta2,
  epsilon, grad). `var`, `m`, `v` and `grad` must have compatible shapes;
  the six hyperparameters in between must be scalars.

  Args:
    op: The ApplyAdam operation.

  Returns:
    A one-element list holding the merged shape of var, m, v and grad.
  """
  merged = op.inputs[0].get_shape()  # var
  for slot_index in (1, 2):  # m, v
    merged = op.inputs[slot_index].get_shape().merge_with(merged)
  # beta1_power, beta2_power, lr, beta1, beta2, epsilon
  for hyper_index in range(3, 9):
    _AssertInputIsScalar(op, hyper_index)
  merged = op.inputs[9].get_shape().merge_with(merged)  # grad
  return [merged]
@ops.RegisterShape("ApplyMomentum")
def _ApplyMomentumShape(op):
  """Infers the output shape of the ApplyMomentum op.

  Inputs are (var, accum, lr, grad, momentum): `var`, `accum` and `grad`
  must have compatible shapes, while `lr` and `momentum` must be scalars.

  Args:
    op: The ApplyMomentum operation.

  Returns:
    A one-element list holding the merged shape of var, accum and grad.
  """
  merged = op.inputs[0].get_shape()  # var
  merged = op.inputs[1].get_shape().merge_with(merged)  # accum
  _AssertInputIsScalar(op, 2)  # lr
  merged = op.inputs[3].get_shape().merge_with(merged)  # grad
  _AssertInputIsScalar(op, 4)  # momentum
  return [merged]
@ops.RegisterShape("ApplyRMSProp")
def _ApplyRMSPropShape(op):
  """Infers the output shape of the ApplyRMSProp op.

  Inputs are (var, ms, mom, lr, rho, momentum, epsilon, grad). `var`, `ms`,
  `mom` and `grad` must have compatible shapes; the four hyperparameters in
  between must be scalars.

  Args:
    op: The ApplyRMSProp operation.

  Returns:
    A one-element list holding the merged shape of var, ms, mom and grad.
  """
  merged = op.inputs[0].get_shape()  # var
  for slot_index in (1, 2):  # ms, mom
    merged = op.inputs[slot_index].get_shape().merge_with(merged)
  for hyper_index in (3, 4, 5, 6):  # lr, rho, momentum, epsilon
    _AssertInputIsScalar(op, hyper_index)
  merged = op.inputs[7].get_shape().merge_with(merged)  # grad
  return [merged]
@ops.RegisterShape("ApplyGradientDescent")
def _ApplyGradientDescentShape(op):
  """Infers the output shape of the ApplyGradientDescent op.

  Inputs are (var, alpha, delta): `var` and `delta` must have compatible
  shapes and `alpha` must be a scalar.

  Args:
    op: The ApplyGradientDescent operation.

  Returns:
    A one-element list holding the merged shape of var and delta.
  """
  var_shape = op.inputs[0].get_shape()
  _AssertInputIsScalar(op, 1)  # alpha
  return [op.inputs[2].get_shape().merge_with(var_shape)]  # delta
@ops.RegisterShape("SparseApplyAdagrad")
def _SparseApplyAdagradShape(op):
  """Shape function for the SparseApplyAdagrad op.

  Inputs are (var, accum, lr, grad, indices):
    * `var` and `accum` must have compatible shapes;
    * `lr` must be a scalar;
    * `grad` must match `var` in every dimension except the 0th;
    * `indices` must be a vector whose length matches `grad`'s 0th dimension.

  Args:
    op: The SparseApplyAdagrad operation.

  Returns:
    A one-element list holding the inferred shape of the updated variable.

  Raises:
    ValueError: If the static shapes violate the constraints above (raised
      by `merge_with` or `_AssertInputIsScalar`).
  """
  var_shape = op.inputs[0].get_shape()
  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
  _AssertInputIsScalar(op, 2)  # lr
  grad_shape = op.inputs[3].get_shape().merge_with(
      tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
  # The indices shape is only validated; the merged result is not needed.
  op.inputs[4].get_shape().merge_with(tensor_shape.vector(grad_shape[0]))
  # grad's trailing dimensions must equal var's, and after the merge above
  # they can carry information that var/accum lacked. Fold them back in so
  # the output shape is as precise as the inputs allow.
  return [accum_shape[:1].concatenate(grad_shape[1:])]
@ops.RegisterShape("SparseApplyMomentum")
def _SparseApplyMomentumShape(op):
  """Shape function for the SparseApplyMomentum op.

  Inputs are (var, accum, lr, grad, indices, momentum):
    * `var` and `accum` must have compatible shapes;
    * `lr` and `momentum` must be scalars;
    * `grad` must match `var` in every dimension except the 0th;
    * `indices` must be a vector whose length matches `grad`'s 0th dimension.

  Args:
    op: The SparseApplyMomentum operation.

  Returns:
    A one-element list holding the inferred shape of the updated variable.

  Raises:
    ValueError: If the static shapes violate the constraints above (raised
      by `merge_with` or `_AssertInputIsScalar`).
  """
  var_shape = op.inputs[0].get_shape()
  accum_shape = op.inputs[1].get_shape().merge_with(var_shape)
  _AssertInputIsScalar(op, 2)  # lr
  grad_shape = op.inputs[3].get_shape().merge_with(
      tensor_shape.TensorShape([None]).concatenate(accum_shape[1:]))
  # The indices shape is only validated; the merged result is not needed.
  op.inputs[4].get_shape().merge_with(tensor_shape.vector(grad_shape[0]))
  _AssertInputIsScalar(op, 5)  # momentum
  # grad's trailing dimensions must equal var's, and after the merge above
  # they can carry information that var/accum lacked. Fold them back in so
  # the output shape is as precise as the inputs allow.
  return [accum_shape[:1].concatenate(grad_shape[1:])]
|