Diffstat (limited to 'tensorflow/contrib/optimizer_v2/adagrad.py')
-rw-r--r--  tensorflow/contrib/optimizer_v2/adagrad.py | 13
1 file changed, 6 insertions(+), 7 deletions(-)
diff --git a/tensorflow/contrib/optimizer_v2/adagrad.py b/tensorflow/contrib/optimizer_v2/adagrad.py
index c333d1e089..25ec475499 100644
--- a/tensorflow/contrib/optimizer_v2/adagrad.py
+++ b/tensorflow/contrib/optimizer_v2/adagrad.py
@@ -64,18 +64,17 @@ class AdagradOptimizer(optimizer_v2.OptimizerV2):
   def _create_vars(self, var_list, state):
     for v in var_list:
-      # TODO(isaprykin): Delete colocate_with(v) from other optimizers and
-      # confirm that colocation will happen anyway.
       dtype = v.dtype.base_dtype
       if v.get_shape().is_fully_defined():
         init = init_ops.constant_initializer(self._initial_accumulator_value,
                                              dtype=dtype)
       else:
-        # Use a Tensor instead of initializer if variable does not have static
-        # shape.
-        init_constant = gen_array_ops.fill(
-            array_ops.shape(v), self._initial_accumulator_value)
-        init = math_ops.cast(init_constant, dtype)
+        def init(v=v, dtype=dtype):
+          # Use a Tensor instead of initializer if variable does not have
+          # static shape.
+          init_constant = gen_array_ops.fill(array_ops.shape(v),
+                                             self._initial_accumulator_value)
+          return math_ops.cast(init_constant, dtype)
       state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
                                          "accumulator")